In [1]:
import pandas as pd
import seaborn as sns
import numpy as np
import sklearn
from sklearn.model_selection import train_test_split
import torch
import math
import pickle


In [32]:
#split the dataset

dataset = pickle.load(open('datasets.pickle','rb'))

batchsize = 256
#train_x,test_x,train_y,test_y = train_test_split(torch.tensor(X),torch.tensor(Y),test_size=0.2,stratify=torch.tensor(Y))
#train_x,val_x,train_y,val_y = train_test_split(train_x,train_y,test_size=0.2,stratify=train_y)
#trainset = torch.utils.data.TensorDataset(train_x,train_y)
trainset = dataset[0]
trainloader = torch.utils.data.DataLoader(trainset,batch_size=batchsize,shuffle=True)

#valset = torch.utils.data.TensorDataset(val_x,val_y)
valset = dataset[1]
valloader = torch.utils.data.DataLoader(valset,batch_size=batchsize,shuffle=True)

#testset = torch.utils.data.TensorDataset(test_x,test_y)

testset = dataset[2]
testloader = torch.utils.data.DataLoader(testset,batch_size=batchsize,shuffle=True)



print(trainset[0][0][:4],trainset[0][1].shape)

input_dim = 4
output_dim = trainset[0][1].shape[0]

tensor([ 209, 5287,    5,    2]) torch.Size([5911])


In [117]:

#define the model

class MLP(torch.nn.Module):

	def __init__(self, in_channels, hidden_channels, out_channels):
		super().__init__()
		self.emb1 = torch.nn.Embedding(100000,32)
		self.emb2 = torch.nn.Embedding(100000,32)
		self.timeday = torch.nn.Linear(2,32)
		self.class1 = torch.nn.Linear(96, hidden_channels)
		self.class2 = torch.nn.Linear(hidden_channels, 32)
		self.class3 = torch.nn.Linear(32,16)
		self.class4 = torch.nn.Linear(16, out_channels)



	def forward(self,data):


		user = self.emb1(data[:,0])
		loc = self.emb2(data[:, 1])
		#print(user.shape,loc.shape)
		hour = data[:,2]
		day = data[:,3]
		timeday = torch.stack([hour,day],dim=-1)
		#print(timeday.shape)
		timeday = self.timeday(timeday.float())
		x = torch.cat((user,loc,timeday),dim=1)
		#print(x.shape)
		x = self.class1(x).relu()
		x = self.class2(x).relu()
		x = self.class3(x).relu()
		x = self.class4(x)

		return x

In [118]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

lr = 0.001
wd = 5e-4

model = MLP(4,64,output_dim).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=wd)
criterion = torch.nn.CrossEntropyLoss()

print(model)


Using device: cuda
MLP(
  (emb1): Embedding(100000, 32)
  (emb2): Embedding(100000, 32)
  (timeday): Linear(in_features=2, out_features=32, bias=True)
  (class1): Linear(in_features=96, out_features=64, bias=True)
  (class2): Linear(in_features=64, out_features=32, bias=True)
  (class3): Linear(in_features=32, out_features=16, bias=True)
  (class4): Linear(in_features=16, out_features=5911, bias=True)
)


In [120]:
import tqdm
numepoch = 500

#training phase

with tqdm.tqdm(total=numepoch*len(trainloader), unit='batches') as pbar:
  losses = []
  for epoch in range(numepoch):

    model.train()
    optimizer.zero_grad()

    print(f"EPOCH = {epoch}")

    for batch in trainloader:

      input = batch[0].to(device)
      label = batch[1].to(device)

      #print(input.shape)
      #print(label.shape)
      #one_hot_encode_labels(label)
      #print(input.shape,input[:,:4].shape)

      out = model(input[:,:4])

      #print(out.shape)
      #out = model(input)
      #out=out.squeeze(dim=1)
      #print(out)


      loss = criterion(out,label)
      losses.append(loss)
      loss.backward()
      optimizer.step()



      pbar.set_postfix({'Loss': loss.item()})
      pbar.update()


      #print(df)


  0%|          | 6/1500 [00:00<00:51, 29.07batches/s, Loss=2]   

EPOCH = 0
EPOCH = 1
EPOCH = 2


  1%|▏         | 22/1500 [00:00<00:29, 50.21batches/s, Loss=1.64]

EPOCH = 3
EPOCH = 4
EPOCH = 5
EPOCH = 6
EPOCH = 7


  3%|▎         | 42/1500 [00:00<00:18, 79.21batches/s, Loss=1.21]

EPOCH = 8
EPOCH = 9
EPOCH = 10
EPOCH = 11
EPOCH = 12
EPOCH = 13
EPOCH = 14


  4%|▍         | 62/1500 [00:00<00:16, 87.99batches/s, Loss=1.2]  

EPOCH = 15
EPOCH = 16
EPOCH = 17
EPOCH = 18
EPOCH = 19
EPOCH = 20
EPOCH = 21


  6%|▌         | 84/1500 [00:01<00:15, 91.18batches/s, Loss=0.8]

EPOCH = 22
EPOCH = 23
EPOCH = 24
EPOCH = 25
EPOCH = 26
EPOCH = 27
EPOCH = 28


  7%|▋         | 104/1500 [00:01<00:14, 93.91batches/s, Loss=0.438]

EPOCH = 29
EPOCH = 30
EPOCH = 31
EPOCH = 32
EPOCH = 33
EPOCH = 34
EPOCH = 35


  8%|▊         | 119/1500 [00:01<00:17, 78.25batches/s, Loss=0.349]

EPOCH = 36
EPOCH = 37
EPOCH = 38
EPOCH = 39
EPOCH = 40


  9%|▉         | 140/1500 [00:01<00:16, 81.14batches/s, Loss=0.374]

EPOCH = 41
EPOCH = 42
EPOCH = 43
EPOCH = 44
EPOCH = 45
EPOCH = 46
EPOCH = 47

 10%|█         | 156/1500 [00:01<00:16, 81.52batches/s, Loss=0.257]


EPOCH = 48
EPOCH = 49
EPOCH = 50
EPOCH = 51
EPOCH = 52


 12%|█▏        | 176/1500 [00:02<00:15, 85.06batches/s, Loss=0.227]

EPOCH = 53
EPOCH = 54
EPOCH = 55
EPOCH = 56
EPOCH = 57
EPOCH = 58
EPOCH = 59


 13%|█▎        | 198/1500 [00:02<00:14, 89.90batches/s, Loss=0.148]

EPOCH = 60
EPOCH = 61
EPOCH = 62
EPOCH = 63
EPOCH = 64
EPOCH = 65
EPOCH = 66


 15%|█▍        | 218/1500 [00:02<00:13, 91.89batches/s, Loss=0.112]

EPOCH = 67
EPOCH = 68
EPOCH = 69
EPOCH = 70
EPOCH = 71
EPOCH = 72
EPOCH = 73

 16%|█▌        | 236/1500 [00:02<00:14, 88.77batches/s, Loss=0.102]


EPOCH = 74
EPOCH = 75
EPOCH = 76
EPOCH = 77
EPOCH = 78
EPOCH = 79


 17%|█▋        | 257/1500 [00:03<00:13, 90.93batches/s, Loss=0.103] 

EPOCH = 80
EPOCH = 81
EPOCH = 82
EPOCH = 83
EPOCH = 84
EPOCH = 85
EPOCH = 86


 19%|█▊        | 278/1500 [00:03<00:13, 91.47batches/s, Loss=0.0639]

EPOCH = 87
EPOCH = 88
EPOCH = 89
EPOCH = 90
EPOCH = 91
EPOCH = 92
EPOCH = 93


 20%|██        | 300/1500 [00:03<00:12, 93.92batches/s, Loss=0.0698]

EPOCH = 94
EPOCH = 95
EPOCH = 96
EPOCH = 97
EPOCH = 98
EPOCH = 99
EPOCH = 100


 21%|██▏       | 321/1500 [00:03<00:12, 94.30batches/s, Loss=0.0621]

EPOCH = 101
EPOCH = 102
EPOCH = 103
EPOCH = 104
EPOCH = 105
EPOCH = 106
EPOCH = 107


 23%|██▎       | 342/1500 [00:03<00:12, 94.93batches/s, Loss=0.0537]

EPOCH = 108
EPOCH = 109
EPOCH = 110
EPOCH = 111
EPOCH = 112
EPOCH = 113
EPOCH = 114


 24%|██▍       | 362/1500 [00:04<00:12, 91.36batches/s, Loss=0.0528]

EPOCH = 115
EPOCH = 116
EPOCH = 117
EPOCH = 118
EPOCH = 119
EPOCH = 120
EPOCH = 121


 26%|██▌       | 384/1500 [00:04<00:11, 94.58batches/s, Loss=0.0463]

EPOCH = 122
EPOCH = 123
EPOCH = 124
EPOCH = 125
EPOCH = 126
EPOCH = 127
EPOCH = 128


 27%|██▋       | 403/1500 [00:04<00:12, 90.84batches/s, Loss=0.045] 

EPOCH = 129
EPOCH = 130
EPOCH = 131
EPOCH = 132
EPOCH = 133
EPOCH = 134


 28%|██▊       | 422/1500 [00:04<00:11, 91.54batches/s, Loss=0.0524]

EPOCH = 135
EPOCH = 136
EPOCH = 137
EPOCH = 138
EPOCH = 139
EPOCH = 140


 29%|██▉       | 440/1500 [00:05<00:12, 88.06batches/s, Loss=0.0369]

EPOCH = 141
EPOCH = 142
EPOCH = 143
EPOCH = 144
EPOCH = 145
EPOCH = 146
EPOCH = 147


 31%|███       | 461/1500 [00:05<00:11, 91.68batches/s, Loss=0.0384]

EPOCH = 148
EPOCH = 149
EPOCH = 150
EPOCH = 151
EPOCH = 152
EPOCH = 153
EPOCH = 154


 32%|███▏      | 484/1500 [00:05<00:11, 91.85batches/s, Loss=0.0353]

EPOCH = 155
EPOCH = 156
EPOCH = 157
EPOCH = 158
EPOCH = 159
EPOCH = 160
EPOCH = 161


 34%|███▎      | 504/1500 [00:05<00:10, 91.03batches/s, Loss=0.039]

EPOCH = 162
EPOCH = 163
EPOCH = 164
EPOCH = 165
EPOCH = 166
EPOCH = 167


 35%|███▍      | 520/1500 [00:05<00:11, 88.80batches/s, Loss=0.0346]

EPOCH = 168
EPOCH = 169
EPOCH = 170
EPOCH = 171
EPOCH = 172
EPOCH = 173


 36%|███▌      | 539/1500 [00:06<00:10, 88.57batches/s, Loss=0.0289]

EPOCH = 174
EPOCH = 175
EPOCH = 176
EPOCH = 177
EPOCH = 178
EPOCH = 179
EPOCH = 180


 37%|███▋      | 561/1500 [00:06<00:10, 89.48batches/s, Loss=0.0378]

EPOCH = 181
EPOCH = 182
EPOCH = 183
EPOCH = 184
EPOCH = 185
EPOCH = 186


 38%|███▊      | 577/1500 [00:06<00:10, 87.54batches/s, Loss=0.0305]

EPOCH = 187
EPOCH = 188
EPOCH = 189
EPOCH = 190
EPOCH = 191
EPOCH = 192


 40%|███▉      | 597/1500 [00:06<00:10, 87.79batches/s, Loss=0.0306]

EPOCH = 193
EPOCH = 194
EPOCH = 195
EPOCH = 196
EPOCH = 197
EPOCH = 198
EPOCH = 199


 41%|████      | 615/1500 [00:07<00:10, 86.50batches/s, Loss=0.0297]

EPOCH = 200
EPOCH = 201
EPOCH = 202
EPOCH = 203
EPOCH = 204
EPOCH = 205


 42%|████▏     | 634/1500 [00:07<00:09, 87.91batches/s, Loss=0.0289]

EPOCH = 206
EPOCH = 207
EPOCH = 208
EPOCH = 209
EPOCH = 210
EPOCH = 211


 43%|████▎     | 652/1500 [00:07<00:09, 86.58batches/s, Loss=0.0301]

EPOCH = 212
EPOCH = 213
EPOCH = 214
EPOCH = 215
EPOCH = 216
EPOCH = 217


 45%|████▍     | 671/1500 [00:07<00:09, 88.37batches/s, Loss=0.0285]

EPOCH = 218
EPOCH = 219
EPOCH = 220
EPOCH = 221
EPOCH = 222
EPOCH = 223
EPOCH = 224


 46%|████▌     | 691/1500 [00:07<00:09, 87.64batches/s, Loss=0.0275]

EPOCH = 225
EPOCH = 226
EPOCH = 227
EPOCH = 228
EPOCH = 229
EPOCH = 230


 47%|████▋     | 710/1500 [00:08<00:09, 86.59batches/s, Loss=0.0259]

EPOCH = 231
EPOCH = 232
EPOCH = 233
EPOCH = 234
EPOCH = 235
EPOCH = 236
EPOCH = 237


 49%|████▊     | 730/1500 [00:08<00:08, 87.62batches/s, Loss=0.025] 

EPOCH = 238
EPOCH = 239
EPOCH = 240
EPOCH = 241
EPOCH = 242
EPOCH = 243


 50%|████▉     | 748/1500 [00:08<00:08, 88.78batches/s, Loss=0.0249]

EPOCH = 244
EPOCH = 245
EPOCH = 246
EPOCH = 247
EPOCH = 248
EPOCH = 249


 51%|█████     | 766/1500 [00:08<00:08, 87.71batches/s, Loss=0.0244]

EPOCH = 250
EPOCH = 251
EPOCH = 252
EPOCH = 253
EPOCH = 254
EPOCH = 255


 52%|█████▏    | 783/1500 [00:08<00:08, 85.44batches/s, Loss=0.0245]

EPOCH = 256
EPOCH = 257
EPOCH = 258
EPOCH = 259
EPOCH = 260
EPOCH = 261


 53%|█████▎    | 799/1500 [00:09<00:08, 77.92batches/s, Loss=0.0221]

EPOCH = 262
EPOCH = 263
EPOCH = 264
EPOCH = 265
EPOCH = 266


 54%|█████▍    | 814/1500 [00:09<00:08, 76.94batches/s, Loss=0.0232]

EPOCH = 267
EPOCH = 268
EPOCH = 269
EPOCH = 270
EPOCH = 271


 55%|█████▌    | 830/1500 [00:09<00:08, 74.49batches/s, Loss=0.0234]

EPOCH = 272
EPOCH = 273
EPOCH = 274
EPOCH = 275
EPOCH = 276


 56%|█████▋    | 846/1500 [00:09<00:08, 76.56batches/s, Loss=0.0232]

EPOCH = 277
EPOCH = 278
EPOCH = 279
EPOCH = 280
EPOCH = 281
EPOCH = 282


 58%|█████▊    | 865/1500 [00:10<00:07, 79.50batches/s, Loss=0.0214]

EPOCH = 283
EPOCH = 284
EPOCH = 285
EPOCH = 286
EPOCH = 287
EPOCH = 288


 59%|█████▉    | 884/1500 [00:10<00:07, 84.46batches/s, Loss=0.019] 

EPOCH = 289
EPOCH = 290
EPOCH = 291
EPOCH = 292
EPOCH = 293
EPOCH = 294
EPOCH = 295


 60%|██████    | 902/1500 [00:10<00:07, 82.58batches/s, Loss=0.0192]

EPOCH = 296
EPOCH = 297
EPOCH = 298
EPOCH = 299
EPOCH = 300
EPOCH = 301


 61%|██████▏   | 919/1500 [00:10<00:07, 74.84batches/s, Loss=0.0207]

EPOCH = 302
EPOCH = 303
EPOCH = 304
EPOCH = 305
EPOCH = 306


 62%|██████▏   | 933/1500 [00:10<00:07, 71.73batches/s, Loss=0.0207]

EPOCH = 307
EPOCH = 308
EPOCH = 309
EPOCH = 310
EPOCH = 311


 63%|██████▎   | 949/1500 [00:11<00:07, 69.71batches/s, Loss=0.0186]

EPOCH = 312
EPOCH = 313
EPOCH = 314
EPOCH = 315
EPOCH = 316


 64%|██████▍   | 965/1500 [00:11<00:07, 70.22batches/s, Loss=0.0198]

EPOCH = 317
EPOCH = 318
EPOCH = 319
EPOCH = 320
EPOCH = 321
EPOCH = 322


 66%|██████▌   | 983/1500 [00:11<00:06, 74.29batches/s, Loss=0.0182]

EPOCH = 323
EPOCH = 324
EPOCH = 325
EPOCH = 326
EPOCH = 327
EPOCH = 328


 67%|██████▋   | 1001/1500 [00:11<00:06, 76.06batches/s, Loss=0.0209]

EPOCH = 329
EPOCH = 330
EPOCH = 331
EPOCH = 332
EPOCH = 333
EPOCH = 334


 68%|██████▊   | 1019/1500 [00:12<00:06, 73.79batches/s, Loss=0.0198]

EPOCH = 335
EPOCH = 336
EPOCH = 337
EPOCH = 338
EPOCH = 339


 69%|██████▉   | 1033/1500 [00:12<00:06, 74.68batches/s, Loss=0.0188]

EPOCH = 340
EPOCH = 341
EPOCH = 342
EPOCH = 343
EPOCH = 344


 70%|██████▉   | 1047/1500 [00:12<00:06, 71.57batches/s, Loss=0.0191]

EPOCH = 345
EPOCH = 346
EPOCH = 347
EPOCH = 348
EPOCH = 349


 71%|███████   | 1066/1500 [00:12<00:05, 76.51batches/s, Loss=0.0182]

EPOCH = 350
EPOCH = 351
EPOCH = 352
EPOCH = 353
EPOCH = 354
EPOCH = 355


 72%|███████▏  | 1084/1500 [00:12<00:05, 81.79batches/s, Loss=0.0192]

EPOCH = 356
EPOCH = 357
EPOCH = 358
EPOCH = 359
EPOCH = 360
EPOCH = 361


 73%|███████▎  | 1102/1500 [00:13<00:04, 83.75batches/s, Loss=0.0175]

EPOCH = 362
EPOCH = 363
EPOCH = 364
EPOCH = 365
EPOCH = 366
EPOCH = 367


 75%|███████▍  | 1120/1500 [00:13<00:04, 85.44batches/s, Loss=0.0229]

EPOCH = 368
EPOCH = 369
EPOCH = 370
EPOCH = 371
EPOCH = 372
EPOCH = 373


 76%|███████▌  | 1137/1500 [00:13<00:04, 85.49batches/s, Loss=0.0185]

EPOCH = 374
EPOCH = 375
EPOCH = 376
EPOCH = 377
EPOCH = 378
EPOCH = 379


 77%|███████▋  | 1157/1500 [00:13<00:03, 88.66batches/s, Loss=0.0221]

EPOCH = 380
EPOCH = 381
EPOCH = 382
EPOCH = 383
EPOCH = 384
EPOCH = 385
EPOCH = 386


 79%|███████▊  | 1178/1500 [00:13<00:03, 89.09batches/s, Loss=0.0192]

EPOCH = 387
EPOCH = 388
EPOCH = 389
EPOCH = 390
EPOCH = 391
EPOCH = 392
EPOCH = 393


 80%|████████  | 1200/1500 [00:14<00:03, 91.42batches/s, Loss=0.0217]

EPOCH = 394
EPOCH = 395
EPOCH = 396
EPOCH = 397
EPOCH = 398
EPOCH = 399
EPOCH = 400


 81%|████████▏ | 1221/1500 [00:14<00:03, 91.52batches/s, Loss=0.0149]

EPOCH = 401
EPOCH = 402
EPOCH = 403
EPOCH = 404
EPOCH = 405
EPOCH = 406
EPOCH = 407


 83%|████████▎ | 1241/1500 [00:14<00:02, 92.91batches/s, Loss=0.0221]

EPOCH = 408
EPOCH = 409
EPOCH = 410
EPOCH = 411
EPOCH = 412
EPOCH = 413
EPOCH = 414


 84%|████████▍ | 1262/1500 [00:14<00:02, 93.29batches/s, Loss=0.0133]

EPOCH = 415
EPOCH = 416
EPOCH = 417
EPOCH = 418
EPOCH = 419
EPOCH = 420
EPOCH = 421


 85%|████████▌ | 1282/1500 [00:15<00:02, 88.57batches/s, Loss=0.0178]

EPOCH = 422
EPOCH = 423
EPOCH = 424
EPOCH = 425
EPOCH = 426
EPOCH = 427


 87%|████████▋ | 1300/1500 [00:15<00:02, 86.66batches/s, Loss=0.0163]

EPOCH = 428
EPOCH = 429
EPOCH = 430
EPOCH = 431
EPOCH = 432
EPOCH = 433


 88%|████████▊ | 1317/1500 [00:15<00:02, 84.14batches/s, Loss=0.0162]

EPOCH = 434
EPOCH = 435
EPOCH = 436
EPOCH = 437
EPOCH = 438
EPOCH = 439


 89%|████████▉ | 1335/1500 [00:15<00:01, 83.95batches/s, Loss=0.0154]

EPOCH = 440
EPOCH = 441
EPOCH = 442
EPOCH = 443
EPOCH = 444
EPOCH = 445


 90%|█████████ | 1353/1500 [00:15<00:01, 83.01batches/s, Loss=0.0155]

EPOCH = 446
EPOCH = 447
EPOCH = 448
EPOCH = 449
EPOCH = 450
EPOCH = 451


 91%|█████████▏| 1371/1500 [00:16<00:01, 82.70batches/s, Loss=0.0147]

EPOCH = 452
EPOCH = 453
EPOCH = 454
EPOCH = 455
EPOCH = 456
EPOCH = 457


 93%|█████████▎| 1389/1500 [00:16<00:01, 84.39batches/s, Loss=0.0147]

EPOCH = 458
EPOCH = 459
EPOCH = 460
EPOCH = 461
EPOCH = 462
EPOCH = 463


 94%|█████████▍| 1409/1500 [00:16<00:01, 88.74batches/s, Loss=0.0185]

EPOCH = 464
EPOCH = 465
EPOCH = 466
EPOCH = 467
EPOCH = 468
EPOCH = 469
EPOCH = 470


 95%|█████████▌| 1429/1500 [00:16<00:00, 89.13batches/s, Loss=0.0145]

EPOCH = 471
EPOCH = 472
EPOCH = 473
EPOCH = 474
EPOCH = 475
EPOCH = 476


 96%|█████████▋| 1447/1500 [00:17<00:00, 87.85batches/s, Loss=0.0166]

EPOCH = 477
EPOCH = 478
EPOCH = 479
EPOCH = 480
EPOCH = 481
EPOCH = 482


 98%|█████████▊| 1466/1500 [00:17<00:00, 87.72batches/s, Loss=0.0187]

EPOCH = 483
EPOCH = 484
EPOCH = 485
EPOCH = 486
EPOCH = 487
EPOCH = 488


 99%|█████████▉| 1485/1500 [00:17<00:00, 92.06batches/s, Loss=0.0143]

EPOCH = 489
EPOCH = 490
EPOCH = 491
EPOCH = 492
EPOCH = 493
EPOCH = 494
EPOCH = 495


100%|██████████| 1500/1500 [00:17<00:00, 85.03batches/s, Loss=0.0143]

EPOCH = 496
EPOCH = 497
EPOCH = 498
EPOCH = 499





In [None]:
import matplotlib.pyplot as plt
#print CE loss evolution
ind = [i for i in range(len(losses))]
losses = [loss.cpu().detach().numpy() for loss in losses]
plt.plot(ind,losses)
plt.show()

In [124]:
import tqdm

criterion = torch.nn.CrossEntropyLoss()

#testing phase

with tqdm.tqdm(total=len(testloader), unit='batches') as pbar:
  losses = []

  model.eval()

  for batch in testloader:

    input = batch[0].to(device)
    label = batch[1].to(device)

    out = model(input[:,:4])

    loss = criterion(out,label)

    losses.append(loss)

    pbar.set_postfix({'Loss': loss.item()})
    pbar.update()

print(sum(losses)/len(losses))


100%|██████████| 1/1 [00:00<00:00, 107.00batches/s, Loss=40]

tensor(40.0012, device='cuda:0', grad_fn=<DivBackward0>)





In [None]:
import matplotlib.pyplot as plt

#print loss evolution

ind = [i for i in range(len(losses))]
losses = [loss.cpu().detach().numpy() for loss in losses]
plt.plot(ind,losses)
plt.show()