In [1]:
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset
import transformers
import json, pickle
from models import *
from tqdm import tqdm
from utils import mapk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class UserCourseDataset(Dataset):
    """Dataset pairing each user's embedding tensor with a multi-hot label vector.

    Args:
        users: sequence of user records; each record's ``"subgroups"`` entry is a
            list of label indices (each in ``[0, total_courses)``) to mark as positive.
        user_embeds: sequence of numpy arrays, one per user and in the same order as
            ``users``. (collate_fn treats them as 2-D ``(length, dim)`` arrays.)
        total_courses: length of the label vector returned for every user.
    """

    def __init__(self, users, user_embeds, total_courses=92):
        self.users = users
        # Convert once up front so __getitem__ only has to index.
        self.user_embeds = [torch.from_numpy(user_embed) for user_embed in user_embeds]
        self.total_courses = total_courses

    def __len__(self):
        """Number of users in the dataset."""
        return len(self.users)

    def __getitem__(self, idx):
        """Return ``(user_embed, labels)`` where ``labels`` is a float multi-hot vector."""
        # Pin the index dtype: torch infers float32 for an empty list (and for
        # float-valued JSON numbers), while scatter_ needs an int64 index.
        positive_idx = torch.tensor(self.users[idx]["subgroups"], dtype=torch.long)
        labels = torch.zeros(self.total_courses)
        labels.scatter_(0, positive_idx, 1)
        user_embed = self.user_embeds[idx]
        return user_embed, labels

def collate_fn(batch):
    """Batch (embedding, labels) pairs, zero-padding embeddings along dim 0.

    Every embedding is a 2-D tensor whose first dimension may differ between
    items; each is extended with rows of zeros up to the longest one in the
    batch. The embedding width is read from the first item.

    Returns:
        (user_embeds, labels): the stacked padded embeddings and the stacked labels.
    """
    longest = max([pair[0].shape[0] for pair in batch])
    width = batch[0][0].shape[1]

    padded = []
    for pair in batch:
        embed = pair[0]
        # Rows of zeros that bring this item up to the batch's longest length.
        filler = torch.zeros(longest - embed.shape[0], width)
        padded.append(torch.cat([embed, filler], dim=0))
    user_embeds = torch.stack(padded)

    targets = []
    for pair in batch:
        targets.append(pair[1])
    labels = torch.stack(targets)

    return user_embeds, labels

In [3]:
# For each split: (JSON file with the user records, pickle with their embeddings).
# The train split is listed first so it is loaded first.
split_files = {
    "train": (
        "processed_datas/combined_train_users.json",
        "processed_datas/combined_train_users_embeddings.pkl",
    ),
    "valid": (
        "processed_datas/combined_valid.json",
        "processed_datas/combined_valid_embeddings.pkl",
    ),
}

datasets = {}
for split, (user_path, user_embed_path) in split_files.items():
    with open(user_path, "r") as f:
        users = json.load(f)

    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # embedding files produced by this project.
    with open(user_embed_path, "rb") as f:
        user_embeds = pickle.load(f)

    datasets[split] = UserCourseDataset(users, user_embeds)

trainset = datasets["train"]
validset = datasets["valid"]

In [4]:
def loss_fn(logits, labels):
    """Binary cross-entropy on logits with positive labels weighted 3x.

    Args:
        logits: raw scores; the last dimension indexes the classes.
        labels: float multi-hot targets with the same shape as ``logits``.

    Returns:
        Scalar tensor: the mean weighted loss over all elements.
    """
    # Build the per-class weight from the logits themselves instead of a fixed
    # length-92 CPU tensor, so other class counts / devices / dtypes also work.
    pos_weight = 3.0 * torch.ones(logits.shape[-1], device=logits.device, dtype=logits.dtype)
    return nn.BCEWithLogitsLoss(pos_weight=pos_weight)(logits, labels)


def eval_metric(logits, labels, k=50):
    """Score the top-``k`` ranked classes of each row against its positive labels.

    Args:
        logits: 2-D tensor of per-class scores, one row per user.
        labels: 2-D multi-hot tensor; entries equal to 1 mark the positive classes.
        k: number of top-ranked classes kept per row and the cutoff passed to ``mapk``.

    Returns:
        Whatever ``utils.mapk`` returns for the batch (presumably mean average
        precision at ``k`` - confirm against ``utils.mapk``).
    """
    preds = torch.argsort(logits, dim=-1, descending=True)[:, :k]
    # (row, column) coordinates of every positive entry, grouped per row below.
    rows, cols = torch.where(labels == 1)
    actual = [cols[rows == i].tolist() for i in range(preds.shape[0])]
    # Use the caller's k here; the cutoff used to be hard-coded to 50.
    return mapk(actual, preds, k)

In [5]:
# Training batches are reshuffled every epoch.
train_loader = DataLoader(trainset, batch_size=32, shuffle=True, collate_fn=collate_fn)
# Validation batches are NOT shuffled: this keeps batch composition (and hence the
# per-batch zero padding and the mean of per-batch metrics) fixed between epochs,
# so validation numbers are comparable from one epoch to the next.
valid_loader = DataLoader(validset, batch_size=32, shuffle=False, collate_fn=collate_fn)
# ValModelv0 comes from the project's `models` module; the model is placed on the GPU.
model = ValModelv0(nhead=8, num_tokens=92, in_size=512).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5, verbose=True, mode='max')

In [13]:
# Track the k best validation scores. Slot i of topk_scores corresponds to the
# checkpoint file "ckpts/val_modelv0_{i}.pth".
k = 3
topk_scores = [0 for _ in range(k)]


for epoch in range(60):
    # Enter training mode at the top of every epoch: if this cell is interrupted
    # during validation and run again, the model would otherwise still be in eval mode.
    model.train()

    # ---- training pass ----
    losses = []
    recs = []
    topks = []
    for user_embeds, labels in tqdm(train_loader):
        user_embeds = user_embeds.cuda()
        # Bring logits back to the CPU, where the labels from the loader live.
        logits = model(user_embeds).cpu()
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
        # "rec": batch mean of the per-user sum of sigmoid probabilities.
        recs.append(torch.sigmoid(logits).sum(dim=-1).mean().item())
        topks.append(eval_metric(logits, labels))
    print("train epoch: {}, loss: {}, rec: {}, topk: {}".format(epoch, np.mean(losses), np.mean(recs), np.mean(topks)))

    # ---- validation pass ----
    losses = []
    recs = []
    topks = []
    model.eval()
    with torch.no_grad():
        for user_embeds, labels in valid_loader:
            user_embeds = user_embeds.cuda()
            logits = model(user_embeds).cpu()
            loss = loss_fn(logits, labels)
            losses.append(loss.item())
            recs.append(torch.sigmoid(logits).sum(dim=-1).mean().item())
            topks.append(eval_metric(logits, labels))
    # Mean of the per-batch validation metric, computed once and reused below.
    valid_topk = np.mean(topks)
    print("valid epoch: {}, loss: {}, rec: {}, topk: {}".format(epoch, np.mean(losses), np.mean(recs), valid_topk))

    # If this epoch beats the worst tracked score, overwrite that slot and its checkpoint.
    if valid_topk > min(topk_scores):
        i = topk_scores.index(min(topk_scores))
        topk_scores[i] = valid_topk
        torch.save(model.state_dict(), "ckpts/val_modelv0_{}.pth".format(i))
    # Leave the model in training mode, as before, once the epoch is finished.
    model.train()
    # lr_scheduler.step(np.mean(topks))

print(topk_scores)

100%|██████████| 1867/1867 [00:54<00:00, 34.12it/s]


train epoch: 0, loss: 0.2633805887371573, rec: 8.930641926988663, topk: 0.4222722120821395
valid epoch: 0, loss: 0.23291627853945063, rec: 8.337474930417407, topk: 0.30907754711022256


100%|██████████| 1867/1867 [00:54<00:00, 34.23it/s]


train epoch: 1, loss: 0.26381993047138813, rec: 8.936663070846254, topk: 0.41866477000184
valid epoch: 1, loss: 0.2328363212828453, rec: 8.489458158776001, topk: 0.31385617430533796


100%|██████████| 1867/1867 [00:54<00:00, 34.16it/s]


train epoch: 2, loss: 0.2636929632877371, rec: 8.936317476028417, topk: 0.41920584622811485
valid epoch: 2, loss: 0.23262736627033778, rec: 8.409374054971632, topk: 0.3093070320736781


100%|██████████| 1867/1867 [00:54<00:00, 34.01it/s]


train epoch: 3, loss: 0.2633291048341469, rec: 8.930079070654667, topk: 0.42011693288051305
valid epoch: 3, loss: 0.23129004080380713, rec: 8.25495588517451, topk: 0.3127525230295533


100%|██████████| 1867/1867 [00:54<00:00, 34.26it/s]


train epoch: 4, loss: 0.2631046356574443, rec: 8.928742397259311, topk: 0.4205925150371123
valid epoch: 4, loss: 0.2315530657850124, rec: 7.733442015700287, topk: 0.30835361134609357


100%|██████████| 1867/1867 [00:54<00:00, 34.00it/s]


train epoch: 5, loss: 0.26282695400312617, rec: 8.925251087531983, topk: 0.42114177957065035
valid epoch: 5, loss: 0.23204701618997606, rec: 7.863789743119543, topk: 0.30481688457283757


100%|██████████| 1867/1867 [00:54<00:00, 34.48it/s]


train epoch: 6, loss: 0.2626246217424753, rec: 8.915832558863292, topk: 0.4219026705450887
valid epoch: 6, loss: 0.23269697974671374, rec: 8.07107574075133, topk: 0.30778532907504613


100%|██████████| 1867/1867 [00:54<00:00, 34.31it/s]


train epoch: 7, loss: 0.26228795164271224, rec: 8.915373454581752, topk: 0.4226960766154043
valid epoch: 7, loss: 0.2324818234358515, rec: 8.357914171376072, topk: 0.3103036070049259


100%|██████████| 1867/1867 [00:54<00:00, 34.31it/s]


train epoch: 8, loss: 0.2621214967969272, rec: 8.912709308679435, topk: 0.42342064270222635
valid epoch: 8, loss: 0.23330078757071232, rec: 7.872859979723836, topk: 0.3029053816513018


100%|██████████| 1867/1867 [00:53<00:00, 34.88it/s]


train epoch: 9, loss: 0.2617958554560064, rec: 8.904940643474005, topk: 0.42343177414947064
valid epoch: 9, loss: 0.2313914662534064, rec: 7.880815977578635, topk: 0.31128877210448863


 48%|████▊     | 887/1867 [00:25<00:28, 34.13it/s]


KeyboardInterrupt: 

In [None]:
# Show the tracked best validation scores again (the training cell also prints
# this list when its loop finishes).
print(topk_scores)

[0.3180500737698619, 0.3131914555527006, 0.31731365500919856, 0.31211073845310017, 0.31571356170372056, 0.3167765206503458, 0.31362011125222683, 0.31175490433947184, 0.31295463092443715, 0.31206028410984915]


: 