In [1]:
import torch
from torch.utils.data.dataloader import DataLoader, Dataset
from tqdm.auto import tqdm
import polars as pl
import numpy as np

# Seed PyTorch's global RNG so that the random draws made through torch
# (e.g. torch.randn / randn_like used further down) are repeatable across runs.
torch.manual_seed(42)

<torch._C.Generator at 0x207bafe9910>

In [2]:
# Load the training and test tables; both CSV files are expected in the
# current working directory.
train = pl.read_csv('train.csv')
test = pl.read_csv('test.csv')

In [3]:
class RecsysDataset(Dataset):
    """Map-style dataset holding categorical and numeric feature columns in memory.

    When ``cat`` is omitted the column layout is derived from ``df``:

    * categorical columns are every ``f_2`` .. ``f_41`` present in ``df``;
    * ``cat_max_val`` maps each categorical column to ``max + 1``;
    * numeric columns are ``f_i_click`` and ``f_i_inst`` for every ``f_i``
      (2 <= i < 42) that is absent from ``df``, followed by ``f_42`` .. ``f_79``.

    When ``cat`` is given, ``cat``, ``cat_max_val`` and ``num`` are stored as
    passed, so a second frame can reuse the layout derived from a first one.

    Attributes:
        ids_values: contents of column ``f_0``.
        cat_values: one array per categorical column, in ``cat`` order.
        num_values: float32 matrix of the numeric columns.
        target: True when ``df`` has an ``is_clicked`` column.
        y: ``is_clicked`` / ``is_installed`` columns (only set when ``target``).
    """

    def __init__(self, df, cat=None, cat_max_val=None, num=None):
        if cat is not None:
            # Reuse a layout computed elsewhere (e.g. from the training frame).
            self.cat = cat
            self.cat_max_val = cat_max_val
            self.num = num
        else:
            candidates = [f'f_{i}' for i in range(2, 42)]
            self.cat = [name for name in candidates if name in df.columns]
            self.cat_max_val = {name: df[name].max() + 1 for name in self.cat}
            missing = [name for name in candidates if name not in df.columns]
            self.num = (
                [f'{name}_click' for name in missing]
                + [f'{name}_inst' for name in missing]
                + [f'f_{i}' for i in range(42, 80)]
            )

        self.ids_values = df['f_0'].to_numpy()
        # Copy each categorical column into its own array.
        self.cat_values = [df[name].to_numpy().copy() for name in self.cat]
        self.num_values = df[self.num].to_numpy().astype(np.float32)

        self.target = 'is_clicked' in df.columns
        if self.target:
            self.y = df[['is_clicked', 'is_installed']].to_numpy()

    def __getitem__(self, idx):
        """Return ``features`` or ``(features, labels)`` for row ``idx``.

        ``features`` is a tuple with one entry per categorical column followed
        by the row of numeric values; ``labels`` is the matching row of ``y``.
        """
        categorical = tuple(column[idx] for column in self.cat_values)
        features = categorical + (self.num_values[idx, :],)
        if not self.target:
            return features
        return features, self.y[idx, :]

    def __len__(self):
        """Number of rows, taken from the numeric feature matrix."""
        return self.num_values.shape[0]
                                                    

In [4]:
# The training dataset derives the column layout (categorical / numeric
# column lists and per-column max + 1); the test dataset reuses exactly
# that layout so both produce features in the same order.
# NOTE(review): a test category id above the training maximum would fall
# outside the embedding table sized from cat_max_val — confirm this cannot
# happen in the data.
ds_train = RecsysDataset(train)
ds_test = RecsysDataset(test, ds_train.cat, ds_train.cat_max_val, ds_train.num)

In [5]:
# Sanity check: number of training rows.
len(ds_train)

3485852

In [6]:
# Inspect the first sample: categorical values followed by the numeric
# feature array (plus the label row when the frame carries targets).
ds_train[0]

((125,
  615,
  5161,
  9,
  20,
  327,
  5789,
  43,
  899,
  45,
  33,
  23,
  array([ 2.62542069e-01,  2.19217807e-01,  2.19835490e-01,  2.17375889e-01,
          2.77158558e-01,  2.16766149e-01,  2.11684540e-01,  2.31442660e-01,
          2.87452847e-01,  2.19865441e-01,  2.19865441e-01,  2.19889849e-01,
          2.19668910e-01,  2.19668910e-01,  2.19668910e-01,  2.19668910e-01,
          1.48095220e-01,  1.48095220e-01,  2.21899375e-01,  3.15789223e-01,
          2.45159000e-01,  2.77863532e-01,  2.10721895e-01,  2.19767764e-01,
          2.16472372e-01,  2.19482780e-01,  2.15644464e-01,  2.18238115e-01,
          1.97120577e-01,  1.77712277e-01,  1.74018294e-01,  1.52818143e-01,
          1.83358669e-01,  1.88248664e-01,  1.65585548e-01,  2.05292925e-01,
          1.24747545e-01,  1.78326860e-01,  1.78326860e-01,  1.78322211e-01,
          1.74029246e-01,  1.74029246e-01,  1.74029246e-01,  1.74029246e-01,
          1.13302715e-01,  1.13302715e-01,  1.56892523e-01,  1.82771131e-0

In [7]:
# Preview the first rows of the raw training frame.
train.head()

f_0,f_1,f_2,f_4,f_6,f_11,f_12,f_13,f_15,f_17,f_18,f_20,f_21,f_22,f_42,f_43,f_44,f_45,f_46,f_47,f_48,f_49,f_50,f_51,f_52,f_53,f_54,f_55,f_56,f_57,f_58,f_59,f_60,f_61,f_62,f_63,f_64,…,f_23_inst,f_24_click,f_24_inst,f_25_click,f_25_inst,f_26_click,f_26_inst,f_27_click,f_27_inst,f_28_click,f_28_inst,f_29_click,f_29_inst,f_30_click,f_30_inst,f_31_click,f_31_inst,f_32_click,f_32_inst,f_33_click,f_33_inst,f_34_click,f_34_inst,f_35_click,f_35_inst,f_36_click,f_36_inst,f_37_click,f_37_inst,f_38_click,f_38_inst,f_39_click,f_39_inst,f_40_click,f_40_inst,f_41_click,f_41_inst
i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
2541162,51,125,615,5161,9,20,327,5789,43,899,45,33,23,0.056146,0.115488,-0.028767,-0.031928,-0.020947,-0.034987,-0.052131,-0.040579,0.989237,-0.282259,-0.050542,-0.104387,0.653084,0.25541,-0.087799,0.684768,-0.346807,-0.05513,-0.00685,-0.10213,0.028966,0.136635,-0.023027,…,0.178327,0.219865,0.178327,0.21989,0.178322,0.219669,0.174029,0.219669,0.174029,0.219669,0.174029,0.219669,0.174029,0.148095,0.113303,0.148095,0.113303,0.221899,0.156893,0.315789,0.182771,0.245159,0.184967,0.277864,0.188139,0.210722,0.170611,0.219768,0.177842,0.216472,0.176785,0.219483,0.173818,0.215644,0.171729,0.218238,0.173167
2541260,49,135,632,5164,19,24,328,5500,48,900,54,32,23,-0.083989,-0.405498,-0.028767,-0.031928,-0.020947,-0.034987,-0.052131,-0.040579,-0.062163,0.839139,-0.125328,-0.104387,-0.14403,-0.021474,-0.170255,-0.082985,0.010644,-0.055284,-0.00685,-0.10213,-0.068463,-0.097453,-0.022173,…,0.178327,0.219865,0.178327,0.21989,0.178322,0.219669,0.174029,0.219669,0.174029,0.219669,0.174029,0.219669,0.174029,0.148095,0.113303,0.148095,0.113303,0.221899,0.156893,0.213572,0.173447,0.1919,0.16194,0.20007,0.169209,0.210722,0.170611,0.219768,0.177842,0.216472,0.176785,0.219483,0.173818,0.215644,0.171729,0.218238,0.173167
2541318,51,135,632,5164,22,22,325,5395,46,898,54,32,23,-0.100948,-0.405819,-0.028767,-0.031928,-0.020947,-0.034987,-0.052131,-0.040579,-0.062163,0.638127,-0.125328,-0.104387,-0.14403,-0.191865,-0.170255,-0.220084,0.00689,-0.05592,-0.00685,-0.075478,-0.045979,-0.097453,-0.021183,…,0.178327,0.219865,0.178327,0.21989,0.178322,0.219669,0.174029,0.219669,0.174029,0.219669,0.174029,0.219669,0.174029,0.148095,0.113303,0.148095,0.113303,0.221899,0.156893,0.213572,0.173447,0.1919,0.16194,0.20007,0.169209,0.210722,0.170611,0.219768,0.177842,0.216472,0.176785,0.219483,0.173818,0.215644,0.171729,0.218238,0.173167
2541770,64,115,600,4618,21,23,324,5631,47,900,53,33,23,0.077866,0.52032,-0.028767,-0.031928,-0.020947,-0.034987,-0.052131,-0.040579,-0.062163,-0.338632,-0.050542,0.033247,0.005429,-0.170566,-0.129027,-0.206374,-0.08047,-0.056814,-0.00685,-0.075478,-0.068463,0.312201,-0.023797,…,0.178327,0.219865,0.178327,0.21989,0.178322,0.219669,0.174029,0.219669,0.174029,0.219669,0.174029,0.219669,0.174029,0.284952,0.229542,0.285276,0.229644,0.21026,0.200825,0.213572,0.173447,0.245159,0.184967,0.277864,0.188139,0.210722,0.170611,0.219768,0.177842,0.216472,0.176785,0.219483,0.173818,0.215644,0.171729,0.218238,0.173167
2542002,62,135,631,5166,22,24,324,5330,48,899,54,32,23,-0.089047,-0.405613,-0.028767,-0.031928,-0.020947,-0.034987,-0.052131,-0.040579,-0.062163,0.318304,-0.125328,-0.104387,-0.14403,-0.191865,-0.170255,-0.220084,0.061718,-0.056104,-0.00685,-0.10213,-0.068463,-0.068192,-0.021414,…,0.178327,0.219865,0.178327,0.21989,0.178322,0.219669,0.174029,0.219669,0.174029,0.219669,0.174029,0.219669,0.174029,0.284952,0.229542,0.285276,0.229644,0.221899,0.156893,0.213572,0.173447,0.1919,0.16194,0.20007,0.169209,0.210722,0.170611,0.219768,0.177842,0.216472,0.176785,0.219483,0.173818,0.215644,0.171729,0.218238,0.173167


In [8]:
class EmbeddedFeatures(torch.nn.Module):
    """Embeds every categorical column and averages the embeddings.

    One ``torch.nn.Embedding`` is created per entry of ``ds_train.cat``, with
    ``ds_train.cat_max_val[column]`` rows and ``dims`` columns.
    """

    def __init__(self, ds_train, dims=32):
        super().__init__()
        tables = [
            torch.nn.Embedding(ds_train.cat_max_val[column], dims)
            for column in ds_train.cat
        ]
        self.embeddings = torch.nn.ModuleList(tables)

    def forward(self, cats):
        """Return the sum of the per-column embeddings divided by ``len(cats)``.

        ``cats`` is a sequence of index tensors, paired positionally with the
        embedding tables.
        """
        total = None
        for indices, table in zip(cats, self.embeddings):
            looked_up = table(indices)
            total = looked_up if total is None else total + looked_up
        return total / len(cats)

class DeepFeatures(torch.nn.Module):
    """Stack of linear layers over [embedded categoricals, numeric features].

    The first layer maps ``dims + <number of numeric columns>`` inputs to
    ``dims`` outputs; each further layer maps ``dims`` to ``dims``. The forward
    pass returns the pre-activation output of every layer.
    """

    def __init__(self, ds_train, embs, depth=3, dims=32):
        super().__init__()
        self.embs = embs
        first_width = dims + ds_train.num_values.shape[1]
        # One input layer plus (depth - 1) square layers.
        in_widths = [first_width] + [dims] * (depth - 1)
        self.deep = torch.nn.ModuleList(
            [torch.nn.Linear(width, dims) for width in in_widths]
        )

    def _jitter(self, tensor, std):
        """Scale ``tensor`` by ``1 + std * N(0, 1)`` noise while training only."""
        if not self.training:
            return tensor
        return tensor * (1 + std * torch.randn_like(tensor))

    def forward(self, cats, nums, std=0.5):
        """Return a list holding each layer's (possibly noised) linear output.

        Args:
            cats: categorical inputs forwarded to the embedding module.
            nums: numeric features, concatenated to the embeddings on dim 1.
            std: scale of the multiplicative noise applied in training mode.
        """
        hidden = self._jitter(torch.cat((self.embs(cats), nums), dim=1), std)
        per_layer = []
        for layer in self.deep:
            pre_activation = self._jitter(layer(hidden), std)
            per_layer.append(pre_activation)
            # The activation feeds the next layer only; the stored output stays linear.
            hidden = torch.nn.functional.leaky_relu(pre_activation)
        return per_layer
    

class DeepMF(torch.nn.Module):
    """Two-headed click / install model built from three feature towers.

    A single ``EmbeddedFeatures`` module is shared by three ``DeepFeatures``
    towers (``base``, ``click``, ``install``). At every tower depth the base
    output is dotted with the click and install outputs, scaled by ``multi``
    and weighted by the matching row of ``att``; the weighted scores are
    summed over depths and passed through a sigmoid.

    Args:
        ds_train: dataset describing the feature layout (forwarded to the
            embedding module and the towers).
        depth: number of linear layers per tower.
        dims: embedding width and tower layer width.
    """

    def __init__(self, ds_train, depth=3, dims=32):
        super().__init__()
        # Forward depth / dims instead of hard-coding 3 / 32, so that the
        # towers, the embeddings and ``att`` always agree with the arguments.
        embds = EmbeddedFeatures(ds_train, dims=dims)
        self.base = DeepFeatures(ds_train, embds, depth=depth, dims=dims)
        self.click = DeepFeatures(ds_train, embds, depth=depth, dims=dims)
        self.install = DeepFeatures(ds_train, embds, depth=depth, dims=dims)
        self.multi = torch.nn.parameter.Parameter(torch.randn((1, 1)))
        # One attention row per tower layer actually built (a tower always has
        # at least one layer, even for depth < 1).
        n_levels = len(self.base.deep)
        self.att = torch.nn.parameter.Parameter(torch.randn((n_levels, 2)))

    def forward(self, cats, nums):
        """Return a two-column tensor: click score, then install score.

        Both columns are sigmoid outputs, i.e. values in (0, 1).
        """
        base = self.base(cats, nums)
        click = self.click(cats, nums)
        install = self.install(cats, nums)
        click_out = None
        install_out = None
        for level, (b, c, i) in enumerate(zip(base, click, install)):
            # Row-wise dot product between the base tower and each head tower.
            c_v = torch.sum(b * c, dim=1, keepdim=True) * self.multi
            i_v = torch.sum(b * i, dim=1, keepdim=True) * self.multi
            c_term = c_v * self.att[level, 0]
            i_term = i_v * self.att[level, 1]
            click_out = c_term if click_out is None else click_out + c_term
            install_out = i_term if install_out is None else install_out + i_term
        out = torch.cat((click_out, install_out), dim=1)
        return torch.sigmoid(out)

In [9]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [10]:
# Build the model with its default depth / dims and move it to the chosen device.
model = DeepMF(ds_train).to(device)

In [11]:
# Training batches are reshuffled on every pass; test batches keep the
# dataset order so predictions can later be matched with ds_test.ids_values.
dl_train = DataLoader(ds_train, batch_size=1024, shuffle=True)
dl_test = DataLoader(ds_test, batch_size=2048, shuffle=False)

In [12]:
def epoch(model, loss_f, optimizer, dl_train, device=device):
    """Run one optimisation pass over ``dl_train`` and return the mean batch loss.

    Args:
        model: module called as ``model(cats, nums)``.
        loss_f: callable ``loss_f(y_pred, y)`` returning a scalar tensor.
        optimizer: optimizer stepping the model's parameters.
        dl_train: iterable of ``(x, y)`` batches, where ``x`` is a sequence
            whose last element holds the numeric features and whose other
            elements are the categorical inputs.
        device: device the batch tensors are moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(dl_train)``.
    """
    # Make the pass independent of whatever mode an earlier cell left the
    # model in: after ``model.eval()`` the training-only noise would otherwise
    # be silently skipped when this function is re-run.
    model.train()
    running_loss = 0
    for x, y in tqdm(dl_train):
        optimizer.zero_grad()
        cats = [c.to(device) for c in x[:-1]]
        nums = x[-1].to(device)
        y = y.float().to(device)
        y_pred = model(cats, nums)
        c_loss = loss_f(y_pred, y)
        c_loss.backward()
        optimizer.step()
        # .item() already returns a Python number, whatever the device.
        running_loss += c_loss.item()
    return running_loss / len(dl_train)

In [13]:
def predict(model, dl_test, device=device):
    """Run the model over ``dl_test`` and return all predictions as one array.

    The model is switched to evaluation mode for the duration of the call and
    its previous mode is restored afterwards, so the result does not depend on
    the caller remembering to call ``model.eval()`` first.

    Args:
        model: module called as ``model(cats, nums)``.
        dl_test: iterable of feature batches ``x``, where the last element of
            ``x`` holds the numeric features and the others the categorical
            inputs.
        device: device the batch tensors are moved to.

    Returns:
        NumPy array with the per-batch outputs concatenated along axis 0.
    """
    was_training = model.training
    model.eval()
    preds = []
    try:
        with torch.no_grad():
            for x in tqdm(dl_test):
                cats = [c.to(device) for c in x[:-1]]
                nums = x[-1].to(device)
                y_pred = model(cats, nums).cpu().numpy()
                preds.append(y_pred)
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)
    preds = np.concatenate(preds, axis=0)
    return preds

In [14]:
# BCELoss takes probabilities, which matches DeepMF.forward ending in a sigmoid.
loss_f = torch.nn.BCELoss()
# RAdam with its default hyper-parameters over all model parameters.
optimizer = torch.optim.RAdam(model.parameters())

In [15]:
# Train for a fixed 40 passes over the training loader, printing the mean
# batch loss after each pass.
for i in range(40):
    l = epoch(model, loss_f, optimizer, dl_train)
    print(f'{i}: Current loss in training {l}')
# Persist only the final weights (no intermediate checkpoints are written).
torch.save(model.state_dict(), 'predict_deep_mf_single_embds_rnd_v2.pt')

  0%|          | 0/3405 [00:00<?, ?it/s]

0: Current loss in training 0.36439792068344207


  0%|          | 0/3405 [00:00<?, ?it/s]

1: Current loss in training 0.3252088072103901


  0%|          | 0/3405 [00:00<?, ?it/s]

2: Current loss in training 0.3210477328860637


  0%|          | 0/3405 [00:00<?, ?it/s]

3: Current loss in training 0.31902911765459874


  0%|          | 0/3405 [00:00<?, ?it/s]

4: Current loss in training 0.31748582051188934


  0%|          | 0/3405 [00:00<?, ?it/s]

5: Current loss in training 0.31630984159643255


  0%|          | 0/3405 [00:00<?, ?it/s]

6: Current loss in training 0.3153477294329505


  0%|          | 0/3405 [00:00<?, ?it/s]

7: Current loss in training 0.3147436399466841


  0%|          | 0/3405 [00:00<?, ?it/s]

8: Current loss in training 0.31390185768383716


  0%|          | 0/3405 [00:00<?, ?it/s]

9: Current loss in training 0.31337633799876413


  0%|          | 0/3405 [00:00<?, ?it/s]

10: Current loss in training 0.3129376463690518


  0%|          | 0/3405 [00:00<?, ?it/s]

11: Current loss in training 0.31249734965890164


  0%|          | 0/3405 [00:00<?, ?it/s]

12: Current loss in training 0.3119985481103261


  0%|          | 0/3405 [00:00<?, ?it/s]

13: Current loss in training 0.31170055409360037


  0%|          | 0/3405 [00:00<?, ?it/s]

14: Current loss in training 0.3114984256230437


  0%|          | 0/3405 [00:00<?, ?it/s]

15: Current loss in training 0.3110231047159775


  0%|          | 0/3405 [00:00<?, ?it/s]

16: Current loss in training 0.3107938399573955


  0%|          | 0/3405 [00:00<?, ?it/s]

17: Current loss in training 0.31055294888835294


  0%|          | 0/3405 [00:00<?, ?it/s]

18: Current loss in training 0.31038949460710197


  0%|          | 0/3405 [00:00<?, ?it/s]

19: Current loss in training 0.3100921085840105


  0%|          | 0/3405 [00:00<?, ?it/s]

20: Current loss in training 0.30990781882961066


  0%|          | 0/3405 [00:00<?, ?it/s]

21: Current loss in training 0.3097589533584584


  0%|          | 0/3405 [00:00<?, ?it/s]

22: Current loss in training 0.3096719021664142


  0%|          | 0/3405 [00:00<?, ?it/s]

23: Current loss in training 0.309499784179371


  0%|          | 0/3405 [00:00<?, ?it/s]

24: Current loss in training 0.30940517945961804


  0%|          | 0/3405 [00:00<?, ?it/s]

25: Current loss in training 0.3090825781072989


  0%|          | 0/3405 [00:00<?, ?it/s]

26: Current loss in training 0.30895812084321234


  0%|          | 0/3405 [00:00<?, ?it/s]

27: Current loss in training 0.3089413782310906


  0%|          | 0/3405 [00:00<?, ?it/s]

28: Current loss in training 0.3087900136448388


  0%|          | 0/3405 [00:00<?, ?it/s]

29: Current loss in training 0.30856192186365394


  0%|          | 0/3405 [00:00<?, ?it/s]

30: Current loss in training 0.3084784091963117


  0%|          | 0/3405 [00:00<?, ?it/s]

31: Current loss in training 0.3084454949381768


  0%|          | 0/3405 [00:00<?, ?it/s]

32: Current loss in training 0.30848171741601826


  0%|          | 0/3405 [00:00<?, ?it/s]

33: Current loss in training 0.30819874682878085


  0%|          | 0/3405 [00:00<?, ?it/s]

34: Current loss in training 0.30808244904232446


  0%|          | 0/3405 [00:00<?, ?it/s]

35: Current loss in training 0.3080151131324656


  0%|          | 0/3405 [00:00<?, ?it/s]

36: Current loss in training 0.3079715166585561


  0%|          | 0/3405 [00:00<?, ?it/s]

37: Current loss in training 0.30796448551610706


  0%|          | 0/3405 [00:00<?, ?it/s]

38: Current loss in training 0.30787411404776327


  0%|          | 0/3405 [00:00<?, ?it/s]

39: Current loss in training 0.3076431905995732


In [16]:
# Switch to evaluation mode so the training-only noise in DeepFeatures is
# disabled, then score the test set.
model.eval()
pred = predict(model, dl_test)

  0%|          | 0/79 [00:00<?, ?it/s]

In [17]:
# Prepend the row ids (column f_0 of the test frame) as a first column.
# NOTE(review): this overwrites `pred`, so re-running the cell prepends a
# second id column — re-run the prediction cell first.
idx = ds_test.ids_values[:, np.newaxis]
pred = np.concatenate((idx, pred), axis=-1)

In [18]:
# Name and type the three columns, then write a tab-separated submission file.
# NOTE(review): `pred` changes type here (array -> polars DataFrame), so the
# previous cell cannot be re-run afterwards without re-predicting.
pred = pl.DataFrame(data=pred, schema=[('RowId', pl.Int32), ('is_clicked', pl.Float32), ('is_installed', pl.Float32)])
pred.write_csv('log2_out_deep_mf_single_embds_rnd_v2.csv', separator='\t')