In [1]:
import torch

In [2]:
import torch.nn as nn

In [3]:
#basic imports for data preprocessing
import pandas as pd
import numpy as np
from sklearn.preprocessing import OneHotEncoder
# Read the training table from disk into a DataFrame.
TRAIN_CSV = "data/train.csv"
data = pd.read_csv(TRAIN_CSV)
#check out the data


In [4]:
# Fixed seed for the row shuffle: the train/hold-out split made later is
# positional, so an unseeded shuffle would give a different split on every run.
SHUFFLE_SEED = 42

data['date'] = pd.to_datetime(data['date'])
# The model cannot consume raw timestamps, so derive calendar features from
# the date instead and treat them as categorical.
data['weekday'] = data['date'].dt.dayofweek.astype('category')
data['month'] = data['date'].dt.month.astype('category')
data['monthday'] = data['date'].dt.day.astype('category')
# year is the only continuous feature, so it is rescaled: first expressed as an
# offset from 2016, then mapped with (offset - 1) / (5 - 1).
# NOTE(review): presumably this maps offsets 1..5 onto [0, 1] - confirm against
# the year range actually present in the data.
data['year'] = data['date'].dt.year.astype(float) - 2016
data['year'] = (data['year'] - 1) / (5 - 1)

# Drop the identifier and the raw date, shuffle every row reproducibly, and
# renumber the index from 0 without keeping the old index as a column.
dataf = (
    data.drop(columns=['row_id', 'date'])
    .sample(frac=1, random_state=SHUFFLE_SEED)
    .reset_index(drop=True)
)

# Target as a float32 column vector of shape (n_rows, 1); remove it from the
# feature frame afterwards.
y = dataf['num_sold'].to_numpy().astype('float32').reshape(-1, 1)
dataf = dataf.drop(columns='num_sold')


In [5]:
# Columns that are one-hot encoded; 'year' is kept as a single numeric column.
CATEGORICAL_COLUMNS = ['weekday', 'month', 'monthday', 'product', 'store', 'country']

# Request a dense array from the encoder. Newer scikit-learn releases name the
# keyword `sparse_output` and reject the old `sparse` spelling, while older
# releases only know `sparse`; try the new name first and fall back.
try:
    ENC = OneHotEncoder(sparse_output=False)
except TypeError:
    ENC = OneHotEncoder(sparse=False)

transforms = ENC.fit_transform(dataf[CATEGORICAL_COLUMNS])
# Scaled year as a column vector so it can be concatenated next to the
# one-hot block.
year = dataf['year'].to_numpy().reshape(-1, 1)



In [6]:
# Number of leading rows used for training; the remainder is held out.
TRAIN_ROWS = 68128

# Column layout: one-hot features, then the scaled year, then the target last.
full_data = np.hstack([transforms, year, y])
training_data, test_data = full_data[:TRAIN_ROWS], full_data[TRAIN_ROWS:]

In [7]:
# The target was concatenated as the last column of full_data, so everything
# before it is a feature. Slicing relative to the end (instead of a hard-coded
# feature count) guarantees the target can never end up inside the feature
# matrix if the encoder produces a different number of columns; a width
# mismatch then surfaces as a shape error in the model rather than as silent
# target leakage.
x_train = training_data[:, :-1]
y_train = training_data[:, -1:]
x_test = test_data[:, :-1]
y_test = test_data[:, -1:]

In [8]:
from torch.utils.data import DataLoader, Dataset

In [9]:
# Convert every split to a float32 tensor (torch.tensor copies the data).
x_train, y_train, x_test, y_test = (
    torch.tensor(split, dtype=torch.float32)
    for split in (x_train, y_train, x_test, y_test)
)

In [10]:
class DataSetMaker(Dataset):
    """Map-style dataset that pairs each feature row with its target.

    Args:
        X: Sized, indexable collection of feature rows.
        y: Sized, indexable collection of targets aligned with ``X``.

    Raises:
        ValueError: If ``X`` and ``y`` do not have the same length.
    """

    def __init__(self, X, y):
        # Fail early: a length mismatch would otherwise only show up as an
        # IndexError mid-epoch, or silently ignore trailing targets.
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.X = X
        self.y = y

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a two-element list ``[features, target]``."""
        return [self.X[idx], self.y[idx]]
# Wrap the tensor splits in datasets. Note this rebinds training_data and
# test_data, which previously held the NumPy splits.
training_data, test_data = (
    DataSetMaker(x_train, y_train),
    DataSetMaker(x_test, y_test),
)

# Mini-batch loaders: the training loader reshuffles on each iteration, the
# hold-out loader keeps a fixed order.
training_data_loader = DataLoader(dataset=training_data, batch_size=64, shuffle=True)
test_data_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False)

In [11]:
class NeuralNetwork(nn.Module):
    """Fully connected regressor: three hidden ReLU layers with dropout.

    Args:
        in_features: Width of each input row (default 63).
        hidden_features: Width of every hidden layer (default 512).
        out_features: Number of regression outputs (default 1).

    The defaults reproduce the original fixed architecture, and the layer
    order inside ``linear_relu_stack`` is unchanged, so ``state_dict`` keys
    stay the same.
    """

    def __init__(self, in_features=63, hidden_features=512, out_features=1):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            # Lighter dropout after the first layer, heavier on the deeper ones.
            nn.Dropout(p=0.2),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, x):
        """Return predictions of shape ``(batch, out_features)`` for input ``x``."""
        return self.linear_relu_stack(x)


model = NeuralNetwork()


In [12]:
# Optimiser step size, mini-batch size, and epoch count.
# batch_size is not read by the DataLoader construction (which passes 64
# directly), and epochs is reassigned in the training cell before use.
learning_rate, batch_size, epochs = 1e-3, 64, 5

In [13]:
# Mean squared error averaged over all elements in the batch
# ("mean" is the default reduction, spelled out here).
loss = nn.MSELoss(reduction="mean")

In [14]:
# Adam over every trainable parameter of the model, using the configured step size.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [15]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches; ``dataloader.dataset``
            must support ``len`` (used only for the progress printout).
        model: Module being trained; it is switched to training mode first.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer that updates ``model``'s parameters.

    Prints the batch loss every 100 batches. Returns ``None``.
    """
    size = len(dataloader.dataset)

    # Make sure train-only layers such as dropout are active, regardless of the
    # mode a previous evaluation left the model in.
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        # Forward pass and loss for this mini-batch.
        pred = model(X)
        batch_loss = loss_fn(pred, y)

        # Backpropagation: clear stale gradients, compute new ones, update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # `current` counts the samples in the batches before this one.
            loss_value, current = batch_loss.item(), batch * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss(pred, y).item()


    test_loss /= num_batches

    print(f"Test Error: Avg loss: {test_loss:>8f} \n")

In [None]:
# Cast parameters to float32, then alternate a training pass and a hold-out
# evaluation for each epoch. This overrides the earlier `epochs` setting.
model.float()
epochs = 1000
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(
        dataloader=training_data_loader,
        model=model,
        loss_fn=loss,
        optimizer=optimizer,
    )
    test_loop(dataloader=test_data_loader, model=model, loss=loss)
print("Done!")

Epoch 1
-------------------------------
loss: 49743.015625  [    0/68128]
loss: 2875.901123  [ 6400/68128]
loss: 1119.693237  [12800/68128]
loss: 1668.882812  [19200/68128]
loss: 1508.697144  [25600/68128]
loss: 1308.741699  [32000/68128]
loss: 1685.847656  [38400/68128]
loss: 1332.372437  [44800/68128]
loss: 1048.627686  [51200/68128]
loss: 1223.028076  [57600/68128]
loss: 940.165588  [64000/68128]
Test Error: Avg loss: 1534.681019 

Epoch 2
-------------------------------
loss: 1863.972168  [    0/68128]
loss: 1671.382568  [ 6400/68128]
loss: 1349.871216  [12800/68128]
loss: 1254.911133  [19200/68128]
loss: 1017.981445  [25600/68128]
loss: 827.617981  [32000/68128]
loss: 1512.383667  [38400/68128]
loss: 692.443359  [44800/68128]
loss: 824.631470  [51200/68128]
loss: 902.957214  [57600/68128]
loss: 685.274658  [64000/68128]
Test Error: Avg loss: 1019.102509 

Epoch 3
-------------------------------
loss: 1320.071899  [    0/68128]
loss: 936.997437  [ 6400/68128]
loss: 863.655151  [128

loss: 330.496246  [ 6400/68128]
loss: 677.062805  [12800/68128]
loss: 786.256226  [19200/68128]
loss: 922.792480  [25600/68128]
loss: 398.424377  [32000/68128]
loss: 621.566895  [38400/68128]
loss: 749.688110  [44800/68128]
loss: 792.137390  [51200/68128]
loss: 278.557190  [57600/68128]
loss: 743.488708  [64000/68128]
Test Error: Avg loss: 647.843368 

Epoch 21
-------------------------------
loss: 730.115845  [    0/68128]
loss: 389.496155  [ 6400/68128]
loss: 918.263428  [12800/68128]
loss: 634.579407  [19200/68128]
loss: 686.848511  [25600/68128]
loss: 495.172546  [32000/68128]
loss: 583.144348  [38400/68128]
loss: 509.969574  [44800/68128]
loss: 1249.435913  [51200/68128]
loss: 696.661987  [57600/68128]
loss: 548.841614  [64000/68128]
Test Error: Avg loss: 634.213858 

Epoch 22
-------------------------------
loss: 645.481689  [    0/68128]
loss: 543.169128  [ 6400/68128]
loss: 699.568848  [12800/68128]
loss: 551.160950  [19200/68128]
loss: 699.915283  [25600/68128]
loss: 300.63873

loss: 442.989410  [19200/68128]
loss: 688.667480  [25600/68128]
loss: 836.024963  [32000/68128]
loss: 332.477875  [38400/68128]
loss: 456.714844  [44800/68128]
loss: 389.827667  [51200/68128]
loss: 384.477509  [57600/68128]
loss: 968.086121  [64000/68128]
Test Error: Avg loss: 596.001606 

Epoch 40
-------------------------------
loss: 566.229004  [    0/68128]
loss: 764.298340  [ 6400/68128]
loss: 306.882507  [12800/68128]
loss: 388.737518  [19200/68128]
loss: 483.152924  [25600/68128]
loss: 492.205261  [32000/68128]
loss: 987.188660  [38400/68128]
loss: 878.351746  [44800/68128]
loss: 993.686096  [51200/68128]
loss: 519.092957  [57600/68128]
loss: 538.670959  [64000/68128]
Test Error: Avg loss: 558.371740 

Epoch 41
-------------------------------
loss: 664.250061  [    0/68128]
loss: 521.034180  [ 6400/68128]
loss: 589.800537  [12800/68128]
loss: 368.202850  [19200/68128]
loss: 456.763367  [25600/68128]
loss: 666.972046  [32000/68128]
loss: 399.870758  [38400/68128]
loss: 726.199646

loss: 534.609375  [32000/68128]
loss: 433.623627  [38400/68128]
loss: 405.429108  [44800/68128]
loss: 533.655884  [51200/68128]
loss: 306.386322  [57600/68128]
loss: 216.833466  [64000/68128]
Test Error: Avg loss: 503.676199 

Epoch 59
-------------------------------
loss: 324.456512  [    0/68128]
loss: 671.412598  [ 6400/68128]
loss: 863.354614  [12800/68128]
loss: 392.087769  [19200/68128]
loss: 521.245911  [25600/68128]
loss: 336.899170  [32000/68128]
loss: 684.961304  [38400/68128]
loss: 482.835724  [44800/68128]
loss: 627.283203  [51200/68128]
loss: 960.389893  [57600/68128]
loss: 613.183105  [64000/68128]
Test Error: Avg loss: 600.685473 

Epoch 60
-------------------------------
loss: 607.810303  [    0/68128]
loss: 313.381897  [ 6400/68128]
loss: 605.054565  [12800/68128]
loss: 325.911743  [19200/68128]
loss: 517.801514  [25600/68128]
loss: 603.553833  [32000/68128]
loss: 281.613831  [38400/68128]
loss: 545.409180  [44800/68128]
loss: 375.363617  [51200/68128]
loss: 588.192505

loss: 446.549225  [44800/68128]
loss: 331.964478  [51200/68128]
loss: 287.715118  [57600/68128]
loss: 507.264008  [64000/68128]
Test Error: Avg loss: 526.259172 

Epoch 78
-------------------------------
loss: 351.173340  [    0/68128]
loss: 626.580383  [ 6400/68128]
loss: 580.677795  [12800/68128]
loss: 483.118958  [19200/68128]
loss: 583.104004  [25600/68128]
loss: 555.799194  [32000/68128]
loss: 402.074707  [38400/68128]
loss: 423.902649  [44800/68128]
loss: 697.336121  [51200/68128]
loss: 332.777283  [57600/68128]
loss: 248.392441  [64000/68128]
Test Error: Avg loss: 514.292119 

Epoch 79
-------------------------------
loss: 396.276398  [    0/68128]
loss: 532.640564  [ 6400/68128]
loss: 529.941040  [12800/68128]
loss: 544.787170  [19200/68128]
loss: 369.299957  [25600/68128]
loss: 498.717743  [32000/68128]
loss: 410.763031  [38400/68128]
loss: 389.465454  [44800/68128]
loss: 500.734741  [51200/68128]
loss: 495.906403  [57600/68128]
loss: 382.254700  [64000/68128]
Test Error: Avg 

loss: 338.492737  [57600/68128]
loss: 656.671204  [64000/68128]
Test Error: Avg loss: 516.051714 

Epoch 97
-------------------------------
loss: 470.389771  [    0/68128]
loss: 581.200745  [ 6400/68128]
loss: 261.871552  [12800/68128]
loss: 512.972290  [19200/68128]
loss: 511.405640  [25600/68128]
loss: 417.771332  [32000/68128]
loss: 317.003906  [38400/68128]
loss: 263.725830  [44800/68128]
loss: 277.388733  [51200/68128]
loss: 385.701538  [57600/68128]
loss: 483.234436  [64000/68128]
Test Error: Avg loss: 502.607037 

Epoch 98
-------------------------------
loss: 275.673157  [    0/68128]
loss: 383.114258  [ 6400/68128]
loss: 409.471954  [12800/68128]
loss: 389.534485  [19200/68128]
loss: 341.851532  [25600/68128]
loss: 501.318390  [32000/68128]
loss: 313.635376  [38400/68128]
loss: 306.003082  [44800/68128]
loss: 559.398682  [51200/68128]
loss: 609.320618  [57600/68128]
loss: 616.085510  [64000/68128]
Test Error: Avg loss: 487.777540 

Epoch 99
-------------------------------
loss

Test Error: Avg loss: 509.900705 

Epoch 116
-------------------------------
loss: 238.003052  [    0/68128]
loss: 407.294464  [ 6400/68128]
loss: 746.819031  [12800/68128]
loss: 324.819000  [19200/68128]
loss: 435.038696  [25600/68128]
loss: 436.424561  [32000/68128]
loss: 306.291473  [38400/68128]
loss: 325.135590  [44800/68128]
loss: 256.655731  [51200/68128]
loss: 240.485153  [57600/68128]
loss: 322.128143  [64000/68128]
Test Error: Avg loss: 503.472873 

Epoch 117
-------------------------------
loss: 352.060272  [    0/68128]
loss: 636.806580  [ 6400/68128]
loss: 459.978790  [12800/68128]
loss: 625.485352  [19200/68128]
loss: 408.545746  [25600/68128]
loss: 525.376465  [32000/68128]
loss: 528.545349  [38400/68128]
loss: 515.559448  [44800/68128]
loss: 244.632294  [51200/68128]
loss: 578.988892  [57600/68128]
loss: 376.158997  [64000/68128]
Test Error: Avg loss: 479.898765 

Epoch 118
-------------------------------
loss: 410.736877  [    0/68128]
loss: 704.246521  [ 6400/68128]
l

loss: 463.185974  [ 6400/68128]
loss: 157.213928  [12800/68128]
loss: 698.338745  [19200/68128]
loss: 487.287994  [25600/68128]
loss: 337.369232  [32000/68128]
loss: 259.389893  [38400/68128]
loss: 417.737000  [44800/68128]
loss: 202.419861  [51200/68128]
loss: 318.299255  [57600/68128]
loss: 651.949707  [64000/68128]
Test Error: Avg loss: 491.300275 

Epoch 136
-------------------------------
loss: 501.747742  [    0/68128]
loss: 297.064880  [ 6400/68128]
loss: 566.934937  [12800/68128]
loss: 261.372528  [19200/68128]
loss: 380.675354  [25600/68128]
loss: 406.645508  [32000/68128]
loss: 405.182953  [38400/68128]
loss: 365.084503  [44800/68128]
loss: 279.672272  [51200/68128]
loss: 437.219330  [57600/68128]
loss: 313.298859  [64000/68128]
Test Error: Avg loss: 482.476874 

Epoch 137
-------------------------------
loss: 350.072021  [    0/68128]
loss: 320.511505  [ 6400/68128]
loss: 215.856430  [12800/68128]
loss: 176.444351  [19200/68128]
loss: 318.478210  [25600/68128]
loss: 453.3127

loss: 291.104309  [19200/68128]
loss: 398.338715  [25600/68128]
loss: 309.687042  [32000/68128]
loss: 269.496094  [38400/68128]
loss: 395.773254  [44800/68128]
loss: 231.355118  [51200/68128]
loss: 178.991898  [57600/68128]
loss: 351.384277  [64000/68128]
Test Error: Avg loss: 489.600621 

Epoch 155
-------------------------------
loss: 340.765503  [    0/68128]
loss: 447.901031  [ 6400/68128]
loss: 396.321289  [12800/68128]
loss: 463.097107  [19200/68128]
loss: 549.796021  [25600/68128]
loss: 567.168518  [32000/68128]
loss: 306.437561  [38400/68128]
loss: 402.235352  [44800/68128]
loss: 367.662231  [51200/68128]
loss: 450.452515  [57600/68128]
loss: 290.306152  [64000/68128]
Test Error: Avg loss: 449.480578 

Epoch 156
-------------------------------
loss: 362.425903  [    0/68128]
loss: 441.659485  [ 6400/68128]
loss: 412.479553  [12800/68128]
loss: 287.781799  [19200/68128]
loss: 425.696655  [25600/68128]
loss: 256.393433  [32000/68128]
loss: 259.299652  [38400/68128]
loss: 333.7824

loss: 228.059097  [32000/68128]
loss: 328.833069  [38400/68128]
loss: 406.324005  [44800/68128]
loss: 369.031372  [51200/68128]
loss: 430.890076  [57600/68128]
loss: 163.821243  [64000/68128]
Test Error: Avg loss: 437.180982 

Epoch 174
-------------------------------
loss: 329.266449  [    0/68128]
loss: 324.845245  [ 6400/68128]
loss: 252.809692  [12800/68128]
loss: 369.688507  [19200/68128]
loss: 405.049561  [25600/68128]
loss: 282.574860  [32000/68128]
loss: 509.643402  [38400/68128]
loss: 361.177368  [44800/68128]
loss: 191.750473  [51200/68128]
loss: 386.520691  [57600/68128]
loss: 188.737976  [64000/68128]
Test Error: Avg loss: 468.622884 

Epoch 175
-------------------------------
loss: 330.916351  [    0/68128]
loss: 241.986145  [ 6400/68128]
loss: 664.537781  [12800/68128]
loss: 283.623047  [19200/68128]
loss: 659.596130  [25600/68128]
loss: 417.057861  [32000/68128]
loss: 230.184784  [38400/68128]
loss: 421.885437  [44800/68128]
loss: 278.246674  [51200/68128]
loss: 500.2498

loss: 251.970215  [44800/68128]
loss: 391.113251  [51200/68128]
loss: 515.715210  [57600/68128]
loss: 403.789246  [64000/68128]
Test Error: Avg loss: 451.647170 

Epoch 193
-------------------------------
loss: 264.555359  [    0/68128]
loss: 465.870361  [ 6400/68128]
loss: 306.836609  [12800/68128]
loss: 318.226410  [19200/68128]
loss: 349.010742  [25600/68128]
loss: 374.693695  [32000/68128]
loss: 425.768097  [38400/68128]
loss: 284.329559  [44800/68128]
loss: 439.171509  [51200/68128]
loss: 325.741028  [57600/68128]
loss: 533.108154  [64000/68128]
Test Error: Avg loss: 479.435490 

Epoch 194
-------------------------------
loss: 346.918579  [    0/68128]
loss: 376.227325  [ 6400/68128]
loss: 152.663040  [12800/68128]
loss: 356.694489  [19200/68128]
loss: 303.921143  [25600/68128]
loss: 223.307190  [32000/68128]
loss: 376.685303  [38400/68128]
loss: 256.804962  [44800/68128]
loss: 491.968506  [51200/68128]
loss: 412.120605  [57600/68128]
loss: 320.285370  [64000/68128]
Test Error: Av

loss: 274.564270  [57600/68128]
loss: 349.065979  [64000/68128]
Test Error: Avg loss: 459.903497 

Epoch 212
-------------------------------
loss: 293.937744  [    0/68128]
loss: 504.539154  [ 6400/68128]
loss: 203.678101  [12800/68128]
loss: 475.528442  [19200/68128]
loss: 302.894135  [25600/68128]
loss: 322.717590  [32000/68128]
loss: 553.647827  [38400/68128]
loss: 269.382385  [44800/68128]
loss: 289.927856  [51200/68128]
loss: 514.320557  [57600/68128]
loss: 380.569275  [64000/68128]
Test Error: Avg loss: 434.284090 

Epoch 213
-------------------------------
loss: 374.891724  [    0/68128]
loss: 236.938248  [ 6400/68128]
loss: 192.117828  [12800/68128]
loss: 328.104767  [19200/68128]
loss: 436.268921  [25600/68128]
loss: 336.342163  [32000/68128]
loss: 213.506546  [38400/68128]
loss: 454.524506  [44800/68128]
loss: 268.669403  [51200/68128]
loss: 326.685883  [57600/68128]
loss: 264.285858  [64000/68128]
Test Error: Avg loss: 485.158285 

Epoch 214
-------------------------------
l