In [1]:
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset
import transformers
import json, pickle
from models import *
from tqdm import tqdm
from utils import mapk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class UserCourseDataset(Dataset):
    def __init__(self, users, user_embeds, total_courses=732):
        self.users = users
        self.user_embeds = [torch.from_numpy(user_embed) for user_embed in user_embeds]
        self.total_courses = total_courses

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        labels = torch.zeros(self.total_courses)
        labels = labels.scatter_(0, torch.tensor(self.users[idx]["labels"]), 1)
        user_embed = self.user_embeds[idx]
        return user_embed, labels

def collate_fn(batch):
    ### pad the user_embeds to the same length
    # print(batch[0][0].shape)
    max_len = max([item[0].shape[0] for item in batch])
    embed_dim = batch[0][0].shape[1]
    user_embeds = torch.stack([torch.cat([item[0], torch.zeros(max_len - item[0].shape[0], embed_dim)], dim=0) for item in batch])
    labels = torch.stack([item[1] for item in batch])
    return user_embeds, labels

In [3]:
user_embed_path = "processed_datas/train_users_embeddings.pkl"
user_path = "processed_datas/train_users.json"
course_embed_path = "processed_datas/train_courses_embeddings.pkl"

with open(user_path, "r") as f:
    users = json.load(f)

with open(user_embed_path, "rb") as f:
    user_embeds = pickle.load(f)

#user_embeds = [user_embed.reshape(1, -1).astype(np.float32) for user_embed in user_embeds]

trainset = UserCourseDataset(users, user_embeds)

user_embed_path = "processed_datas/val_unseen_embeddings.pkl"
user_path = "processed_datas/val_unseen.json"

with open(user_path, "r") as f:
    users = json.load(f)

with open(user_embed_path, "rb") as f:
    user_embeds = pickle.load(f)

#user_embeds = [user_embed.reshape(1, -1).astype(np.float32) for user_embed in user_embeds]


validset = UserCourseDataset(users, user_embeds)

In [4]:
train_labels = [trainset[i][1] for i in range(len(trainset))]
train_labels_count = np.sum(train_labels, axis=0)
json.dump(np.array(train_labels_count).astype(np.int).tolist(), open("processed_datas/train_labels_count.json", "w"))
negative_weight = train_labels_count / len(trainset)

  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  json.dump(np.array(train_labels_count).astype(np.int).tolist(), open("processed_datas/train_labels_count.json", "w"))


In [5]:
# def loss_fn(logits, labels, rec_num=50, rec_weight=0.001):
#     positive_mask = labels.type(torch.bool)
#     positive_loss = nn.BCELoss()(logits[positive_mask], labels[positive_mask])
#     rec_loss = nn.L1Loss()(logits.sum(dim=-1), rec_num*torch.ones(logits.shape[0]))
#     return positive_loss + rec_weight * rec_loss

def loss_fn(logits, labels):
    return nn.BCEWithLogitsLoss(pos_weight=200 * torch.ones(732))(logits, labels)


def eval_metric(logits, labels, k=50):
    preds = torch.argsort(logits, dim=-1, descending=True)[:, :k]
    labels = torch.where(labels == 1)
    labels = [labels[1][labels[0] == i].tolist() for i in range(preds.shape[0])]
    return mapk(labels, preds, 50)

In [6]:
train_loader = DataLoader(trainset, batch_size=32, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(validset, batch_size=32, shuffle=True, collate_fn=collate_fn)
model = ValModelv1(nhead=8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [7]:
best_topk = 0

for epoch in range(100):
    losses = []
    recs = []
    topks = []
    for user_embeds, labels in tqdm(train_loader):
        user_embeds = user_embeds.cuda()
        logits = model(user_embeds).cpu()
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
        recs.append(torch.sigmoid(logits).sum(dim=-1).mean().item())
        topks.append(eval_metric(logits, labels))
    print("train epoch: {}, loss: {}, rec: {}, topk: {}".format(epoch, np.mean(losses), np.mean(recs), np.mean(topks)))

    losses = []
    recs = []
    topks = []
    model.eval()
    with torch.no_grad():
        for user_embeds, labels in valid_loader:
            user_embeds = user_embeds.cuda()
            logits = model(user_embeds).cpu()
            loss = loss_fn(logits, labels)
            losses.append(loss.item())
            recs.append(torch.sigmoid(logits).sum(dim=-1).mean().item())
            topks.append(eval_metric(logits, labels))
    print("valid epoch: {}, loss: {}, rec: {}, topk: {}".format(epoch, np.mean(losses), np.mean(recs), np.mean(topks)))
    if np.mean(topks) > best_topk:
        best_topk = np.mean(topks)
        torch.save(model.state_dict(), "val_modelv1.pth")
    model.train()

100%|██████████| 1867/1867 [01:12<00:00, 25.69it/s]


train epoch: 0, loss: 0.4033388276916171, rec: 330.6822677064721, topk: 0.030925384256908207
valid epoch: 0, loss: 0.32791201246308754, rec: 275.1023646763393, topk: 0.029817128165930153


100%|██████████| 1867/1867 [01:11<00:00, 26.01it/s]


train epoch: 1, loss: 0.37953317810267, rec: 296.19407363670933, topk: 0.04562856524530621
valid epoch: 1, loss: 0.31842441586675224, rec: 256.8728516966432, topk: 0.03935474293571417


100%|██████████| 1867/1867 [01:11<00:00, 26.02it/s]


train epoch: 2, loss: 0.37428991996992955, rec: 289.38067498638725, topk: 0.04884069129999926
valid epoch: 2, loss: 0.32483375506414164, rec: 293.71946909139444, topk: 0.03609003507125146


  6%|▌         | 112/1867 [00:04<01:08, 25.47it/s]


KeyboardInterrupt: 

In [7]:
train_loader = DataLoader(trainset, batch_size=32, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(validset, batch_size=32, shuffle=True, collate_fn=collate_fn)
model = ValModelv0(nhead=8, in_size=512, dropout=0.1, hidden_size=256).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [8]:
best_score = 0


for epoch in range(30):
    losses = []
    recs = []
    topks = []
    for user_embeds, labels in tqdm(train_loader):
        user_embeds = user_embeds.cuda()
        logits = model(user_embeds).cpu()
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
        recs.append(torch.sigmoid(logits).sum(dim=-1).mean().item())
        topks.append(eval_metric(logits, labels))
    print("train epoch: {}, loss: {}, rec: {}, topk: {}".format(epoch, np.mean(losses), np.mean(recs), np.mean(topks)))

    losses = []
    recs = []
    topks = []
    model.eval()
    with torch.no_grad():
        for user_embeds, labels in valid_loader:
            user_embeds = user_embeds.cuda()
            logits = model(user_embeds).cpu()
            loss = loss_fn(logits, labels)
            losses.append(loss.item())
            recs.append(torch.sigmoid(logits).sum(dim=-1).mean().item())
            topks.append(eval_metric(logits, labels))
    print("valid epoch: {}, loss: {}, rec: {}, topk: {}".format(epoch, np.mean(losses), np.mean(recs), np.mean(topks)))
    if np.mean(topks) > best_score:
        best_score = np.mean(topks)
        torch.save(model.state_dict(), "val_modelv0.pth")
    model.train()

100%|██████████| 1867/1867 [00:58<00:00, 32.12it/s]


train epoch: 0, loss: 0.6755380030934927, rec: 164.993839813543, topk: 0.16413441316519223
valid epoch: 0, loss: 1.276698078264247, rec: 128.60163414085304, topk: 0.0766969926483974


100%|██████████| 1867/1867 [00:57<00:00, 32.46it/s]


train epoch: 1, loss: 0.6028936509148731, rec: 142.6512653532677, topk: 0.19043870196842944
valid epoch: 1, loss: 1.4861792299773666, rec: 109.47398064162705, topk: 0.0821112780426354


100%|██████████| 1867/1867 [00:57<00:00, 32.52it/s]


train epoch: 2, loss: 0.5780242734818645, rec: 135.60534616888177, topk: 0.20315493772580492
valid epoch: 2, loss: 1.6541208261316949, rec: 111.33684789217435, topk: 0.08848893809739125


100%|██████████| 1867/1867 [00:57<00:00, 32.68it/s]


train epoch: 3, loss: 0.5644772513447819, rec: 131.93885116291096, topk: 0.20936272898262343
valid epoch: 3, loss: 1.6262995424670177, rec: 117.38719694955009, topk: 0.10342494296722145


100%|██████████| 1867/1867 [00:57<00:00, 32.61it/s]


train epoch: 4, loss: 0.553471740644503, rec: 129.17947395050916, topk: 0.2140272677124576
valid epoch: 4, loss: 1.765723467364416, rec: 113.61467977670523, topk: 0.10306487704780257


100%|██████████| 1867/1867 [00:57<00:00, 32.70it/s]


train epoch: 5, loss: 0.5453932853580223, rec: 126.91718217454368, topk: 0.2155389231163919
valid epoch: 5, loss: 2.0340358441347606, rec: 100.11638869820061, topk: 0.10479374189497277


100%|██████████| 1867/1867 [00:57<00:00, 32.59it/s]


train epoch: 6, loss: 0.5386738210166273, rec: 125.252509147264, topk: 0.2185500580371749
valid epoch: 6, loss: 2.078260803615654, rec: 109.88975007193429, topk: 0.10196218173809811


100%|██████████| 1867/1867 [00:57<00:00, 32.61it/s]


train epoch: 7, loss: 0.5324629967584288, rec: 123.85094782089128, topk: 0.21942103011115496
valid epoch: 7, loss: 2.1953829310752533, rec: 102.73514885954805, topk: 0.10570963321332695


100%|██████████| 1867/1867 [00:57<00:00, 32.68it/s]


train epoch: 8, loss: 0.5263505103470619, rec: 122.0814146420888, topk: 0.22165105730413381
valid epoch: 8, loss: 2.1891650338094313, rec: 107.83895607833024, topk: 0.10067057316625545


100%|██████████| 1867/1867 [00:57<00:00, 32.56it/s]


train epoch: 9, loss: 0.5196851408762541, rec: 120.6750876371019, topk: 0.2233213276142079
valid epoch: 9, loss: 2.289363757594601, rec: 108.31105007968107, topk: 0.10547693398331065


100%|██████████| 1867/1867 [00:57<00:00, 32.52it/s]


train epoch: 10, loss: 0.5137105180448431, rec: 119.16264700008448, topk: 0.22343351281158205
valid epoch: 10, loss: 2.3145128328066606, rec: 102.71819395547385, topk: 0.10185518094841636


100%|██████████| 1867/1867 [00:57<00:00, 32.49it/s]


train epoch: 11, loss: 0.5074013094086026, rec: 117.7758975322705, topk: 0.22345917686078032
valid epoch: 11, loss: 2.3528976430604747, rec: 100.64430827884884, topk: 0.10768030617549103


100%|██████████| 1867/1867 [00:57<00:00, 32.47it/s]


train epoch: 12, loss: 0.5017709277598182, rec: 116.3920607073906, topk: 0.2244693527230246
valid epoch: 12, loss: 2.457051850609727, rec: 103.87707167405348, topk: 0.09942613983843857


100%|██████████| 1867/1867 [00:57<00:00, 32.23it/s]


train epoch: 13, loss: 0.49482603641710754, rec: 114.86037361296218, topk: 0.22490919503365447
valid epoch: 13, loss: 2.3146764226667176, rec: 103.85982790098086, topk: 0.10316227361871633


100%|██████████| 1867/1867 [00:57<00:00, 32.58it/s]


train epoch: 14, loss: 0.4881361583751603, rec: 113.27570393280863, topk: 0.224100372639881
valid epoch: 14, loss: 2.425076356300941, rec: 100.47506741115025, topk: 0.10725924485228637


100%|██████████| 1867/1867 [00:57<00:00, 32.67it/s]


train epoch: 15, loss: 0.4810872413317788, rec: 111.62397964305909, topk: 0.2239141198846739
valid epoch: 15, loss: 2.3180025349279028, rec: 109.94460451733936, topk: 0.10436176280578621


100%|██████████| 1867/1867 [00:57<00:00, 32.73it/s]


train epoch: 16, loss: 0.4746067031188387, rec: 110.17828733651601, topk: 0.22391975325651284
valid epoch: 16, loss: 2.3501097013007155, rec: 116.21787513481392, topk: 0.10319883995855918


100%|██████████| 1867/1867 [00:57<00:00, 32.63it/s]


train epoch: 17, loss: 0.466961472939219, rec: 108.47853752835013, topk: 0.22382289260202815
valid epoch: 17, loss: 2.4167692278112685, rec: 105.33131004165817, topk: 0.10393228268873539


100%|██████████| 1867/1867 [00:57<00:00, 32.64it/s]


train epoch: 18, loss: 0.45975393289924626, rec: 106.57046172065918, topk: 0.22466804434080506
valid epoch: 18, loss: 2.4877915398760155, rec: 108.41508607549981, topk: 0.1041999714013897


100%|██████████| 1867/1867 [00:57<00:00, 32.64it/s]


train epoch: 19, loss: 0.4522537743911171, rec: 104.9320339249875, topk: 0.22362265995895375
valid epoch: 19, loss: 2.4721046895771237, rec: 103.4971070761209, topk: 0.09892031262096332


100%|██████████| 1867/1867 [00:57<00:00, 32.39it/s]


train epoch: 20, loss: 0.4453026420079136, rec: 103.31549431055576, topk: 0.22469868651172595
valid epoch: 20, loss: 2.5119570591947533, rec: 106.74534776708582, topk: 0.09549705882607491


100%|██████████| 1867/1867 [00:56<00:00, 32.87it/s]


train epoch: 21, loss: 0.43787223290618116, rec: 101.49100778392757, topk: 0.2221233203508489
valid epoch: 21, loss: 2.7716301681248696, rec: 95.61096998361441, topk: 0.10009232126867175


100%|██████████| 1867/1867 [00:57<00:00, 32.34it/s]


train epoch: 22, loss: 0.43063990577823924, rec: 99.82187300744897, topk: 0.22149303699030762
valid epoch: 22, loss: 2.7841177576190823, rec: 96.85222386789846, topk: 0.09579228638442124


100%|██████████| 1867/1867 [00:58<00:00, 31.90it/s]


train epoch: 23, loss: 0.42368293239525057, rec: 98.01173404556708, topk: 0.22129344618203087
valid epoch: 23, loss: 2.8004819554286997, rec: 99.9300924028669, topk: 0.09743344879256494


100%|██████████| 1867/1867 [00:58<00:00, 31.85it/s]


train epoch: 24, loss: 0.41660215413117063, rec: 96.43606515893593, topk: 0.22013845291885106
valid epoch: 24, loss: 2.777795436618092, rec: 95.15139602828812, topk: 0.09482387284880492


100%|██████████| 1867/1867 [00:57<00:00, 32.56it/s]


train epoch: 25, loss: 0.41021585777346053, rec: 94.74756771061085, topk: 0.21873659694479033
valid epoch: 25, loss: 2.820357728790451, rec: 92.9377202463674, topk: 0.09331308906476606


100%|██████████| 1867/1867 [00:58<00:00, 32.18it/s]


train epoch: 26, loss: 0.40430883371261717, rec: 93.26246732968217, topk: 0.21790205326191153
valid epoch: 26, loss: 2.795144809471382, rec: 97.58856435922476, topk: 0.08872841130845835


100%|██████████| 1867/1867 [00:57<00:00, 32.22it/s]


train epoch: 27, loss: 0.3976326773175596, rec: 91.75496312750809, topk: 0.21942767735660557
valid epoch: 27, loss: 3.0455690036108205, rec: 91.69885863838616, topk: 0.07947509644261551


100%|██████████| 1867/1867 [00:58<00:00, 31.68it/s]


train epoch: 28, loss: 0.39198306962712026, rec: 90.35867240918022, topk: 0.2168464699167425
valid epoch: 28, loss: 2.927065229022896, rec: 96.01369193360046, topk: 0.08294846985021492


100%|██████████| 1867/1867 [00:57<00:00, 32.60it/s]


train epoch: 29, loss: 0.38577445560939566, rec: 88.72915635520147, topk: 0.21695910393882917
valid epoch: 29, loss: 2.9405050595383067, rec: 97.51323400224958, topk: 0.08022687792457599


In [17]:
train_loader = DataLoader(trainset, batch_size=512, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(validset, batch_size=512, shuffle=True, collate_fn=collate_fn)

model = ValModelv2().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

def loss_fn(logits, labels, margin=0.1, negative_weight=0.3):
    positive_mask = labels.type(torch.bool)
    positive_loss = torch.clamp(1.0 - logits[positive_mask]-margin, min=0).mean()
    negative_loss = logits[~positive_mask].mean()
    return positive_loss + negative_weight * negative_loss

In [18]:
for epoch in range(10):
    losses = []
    recs = []
    topks = []
    for user_embeds, labels in tqdm(train_loader):
        user_embeds = user_embeds.cuda()
        logits = model(user_embeds).cpu()
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
        recs.append(logits.sum(dim=-1).mean().item())
        topks.append(eval_metric(logits, labels))
    print("train epoch: {}, loss: {}, rec: {}, topk: {}".format(epoch, np.mean(losses), np.mean(recs), np.mean(topks)))

    losses = []
    recs = []
    topks = []
    model.eval()
    with torch.no_grad():
        for user_embeds, labels in valid_loader:
            user_embeds = user_embeds.cuda()
            logits = model(user_embeds).cpu()
            loss = loss_fn(logits, labels)
            losses.append(loss.item())
            recs.append(logits.sum(dim=-1).mean().item())
            topks.append(eval_metric(logits, labels))
    print("valid epoch: {}, loss: {}, rec: {}, topk: {}".format(epoch, np.mean(losses), np.mean(recs), np.mean(topks)))
    model.train()

100%|██████████| 117/117 [00:44<00:00,  2.60it/s]


train epoch: 0, loss: 0.766678167714013, rec: 4.449799065151785, topk: 0.14409379729010702
valid epoch: 0, loss: 0.8116465029509171, rec: 2.4035574975221055, topk: 0.07683248175912097


100%|██████████| 117/117 [00:45<00:00,  2.59it/s]


train epoch: 1, loss: 0.7304033927428417, rec: 2.534630176348564, topk: 0.1774733906402815
valid epoch: 1, loss: 0.7995522151822629, rec: 1.314250567685003, topk: 0.08595529080869732


100%|██████████| 117/117 [00:45<00:00,  2.59it/s]


train epoch: 2, loss: 0.7056796688299912, rec: 0.48750328705208296, topk: 0.18878886096510775
valid epoch: 2, loss: 0.7931088001831718, rec: -0.5277866954388826, topk: 0.08964996539978001


100%|██████████| 117/117 [00:45<00:00,  2.58it/s]


train epoch: 3, loss: 0.6828133861223856, rec: -1.8289444759870186, topk: 0.19182157897608665
valid epoch: 3, loss: 0.7855333530384562, rec: -2.6289356273153555, topk: 0.08898914107559416


100%|██████████| 117/117 [00:45<00:00,  2.56it/s]


train epoch: 4, loss: 0.6599316439058027, rec: -4.2461754277221155, topk: 0.19353123097044944
valid epoch: 4, loss: 0.7791087083194567, rec: -5.1508641450301464, topk: 0.0884494247672093


100%|██████████| 117/117 [00:45<00:00,  2.57it/s]


train epoch: 5, loss: 0.6371049442861834, rec: -6.818087638952793, topk: 0.19334731663582247
valid epoch: 5, loss: 0.7741616549699203, rec: -7.222916872605033, topk: 0.08584312275092106


100%|██████████| 117/117 [00:45<00:00,  2.58it/s]


train epoch: 6, loss: 0.6141945052350688, rec: -9.374691555642674, topk: 0.19287406066290558
valid epoch: 6, loss: 0.7652964980705924, rec: -9.348411145417586, topk: 0.09144101392728285


100%|██████████| 117/117 [00:45<00:00,  2.57it/s]


train epoch: 7, loss: 0.5917456557608058, rec: -11.967124009743715, topk: 0.1919395911850983
valid epoch: 7, loss: 0.7585260842157446, rec: -12.235948977263078, topk: 0.08691540319209114


100%|██████████| 117/117 [00:45<00:00,  2.59it/s]


train epoch: 8, loss: 0.569560245061532, rec: -14.517706903636965, topk: 0.19170585932846912
valid epoch: 8, loss: 0.754508508288342, rec: -14.345203524050506, topk: 0.08492243145754241


100%|██████████| 117/117 [00:45<00:00,  2.57it/s]


train epoch: 9, loss: 0.5479295992443705, rec: -17.033421760950333, topk: 0.1903384165568237
valid epoch: 9, loss: 0.7471263201340385, rec: -16.769900695137355, topk: 0.08650586522202565


In [7]:
train_loader = DataLoader(trainset, batch_size=32, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(validset, batch_size=32, shuffle=True, collate_fn=collate_fn)

model = ValModelv3().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

def loss_fn(logits, labels, margin=0.1, negative_weight=0.3):
    positive_mask = labels.type(torch.bool)
    positive_loss = torch.clamp(1.0 - logits[positive_mask]-margin, min=0).mean()
    negative_loss = logits[~positive_mask].mean()
    return positive_loss + negative_weight * negative_loss

  return torch.tensor(dataset)


In [8]:
for epoch in range(100):
    losses = []
    recs = []
    topks = []
    for user_embeds, labels in tqdm(train_loader):
        user_embeds = user_embeds.cuda()
        logits = model(user_embeds).cpu()
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
        recs.append(logits.sum(dim=-1).mean().item())
        topks.append(eval_metric(logits, labels))
    print("train epoch: {}, loss: {}, rec: {}, topk: {}".format(epoch, np.mean(losses), np.mean(recs), np.mean(topks)))

    losses = []
    recs = []
    topks = []
    model.eval()
    with torch.no_grad():
        for user_embeds, labels in valid_loader:
            user_embeds = user_embeds.cuda()
            logits = model(user_embeds).cpu()
            loss = loss_fn(logits, labels)
            losses.append(loss.item())
            recs.append(logits.sum(dim=-1).mean().item())
            topks.append(eval_metric(logits, labels))
    print("valid epoch: {}, loss: {}, rec: {}, topk: {}".format(epoch, np.mean(losses), np.mean(recs), np.mean(topks)))
    model.train()

100%|██████████| 1867/1867 [03:11<00:00,  9.74it/s]


train epoch: 0, loss: 0.203832466253062, rec: 173.9995898638553, topk: 0.015023128256834932
valid epoch: 0, loss: 0.5313135508325074, rec: 88.25882869762378, topk: 0.01273372787846611


100%|██████████| 1867/1867 [03:10<00:00,  9.78it/s]


train epoch: 1, loss: 0.19324947941906642, rec: 102.03775221789401, topk: 0.015853206457859068
valid epoch: 1, loss: 0.8500865481712006, rec: -21.5492885021063, topk: 0.0053002652279394


  5%|▍         | 93/1867 [00:09<03:01,  9.80it/s]


KeyboardInterrupt: 

In [7]:
train_loader = DataLoader(trainset, batch_size=512, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(validset, batch_size=512, shuffle=True, collate_fn=collate_fn)

model = ValModelv4().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [8]:
for epoch in range(30):
    losses = []
    recs = []
    topks = []
    for user_embeds, labels in tqdm(train_loader):
        user_embeds = user_embeds.cuda()
        logits = model(user_embeds).cpu()
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
        recs.append(torch.sigmoid(logits).sum(dim=-1).mean().item())
        topks.append(eval_metric(logits, labels))
    print("train epoch: {}, loss: {}, rec: {}, topk: {}".format(epoch, np.mean(losses), np.mean(recs), np.mean(topks)))

    losses = []
    recs = []
    topks = []
    model.eval()
    with torch.no_grad():
        for user_embeds, labels in valid_loader:
            user_embeds = user_embeds.cuda()
            logits = model(user_embeds).cpu()
            loss = loss_fn(logits, labels)
            losses.append(loss.item())
            recs.append(torch.sigmoid(logits).sum(dim=-1).mean().item())
            topks.append(eval_metric(logits, labels))
    print("valid epoch: {}, loss: {}, rec: {}, topk: {}".format(epoch, np.mean(losses), np.mean(recs), np.mean(topks)))
    model.train()

 50%|████▉     | 58/117 [00:32<00:33,  1.78it/s]


KeyboardInterrupt: 