In [1]:
from BookDataset import getBookDataset
import pytorch_lightning as pl
from torch import nn
import torch
import matplotlib.pyplot as plt
import pandas as pd
import torch.nn.functional as F
import numpy as np
from torch.distributions import LogNormal
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

In [2]:
# Load the raw Excel files through the project helper, which yields three datasets
# (train / val / test) that are later wrapped in DataLoaders.
# NOTE(review): how `val` is produced (presumably a split of Data_Train.xlsx) is not
# visible here — confirm in BookDataset.getBookDataset.
train, val, test = getBookDataset('Data_Train.xlsx', 'Data_Test.xlsx')

In [3]:
# Vocabulary size of each categorical column (largest integer code + 1); these are the
# embedding-table sizes hard-coded as `inDims` in BookPrice below.
train.frame[['Author', 'Genre', 'BookCategory', 'Edition']].max() + 1

Author          482
Genre           201
BookCategory     12
Edition           2
dtype: int64

In [4]:
# Summary statistics of the two numerical features that BookPrice feeds to its MLP.
train.frame[['Reviews', 'Ratings']].describe()

Unnamed: 0,Reviews,Ratings
count,4989.0,4989.0
mean,4.289316,2.047692
std,0.670264,1.563154
min,1.0,0.0
25%,4.0,0.693147
50%,4.4,1.94591
75%,4.8,3.091043
max,5.0,8.714403


In [5]:
class BookPrice(pl.LightningModule):
    """Log-normal regression model for book prices.

    Each categorical column (Author, Genre, BookCategory, Edition) is embedded,
    the embeddings are concatenated with the two numerical columns (Reviews,
    Ratings), and a small MLP maps the result to the parameters ``(mu, sigma)``
    of a LogNormal distribution over ``Price``. Training minimises the LogNormal
    negative log-likelihood; point predictions are the distribution's mean.
    """

    def __init__(self, inDims=(482, 201, 12, 2)):
        """
        Args:
            inDims: vocabulary size of each categorical column, in the order
                Author, Genre, BookCategory, Edition. The defaults are the
                ``max() + 1`` values computed from the training frame.

        Raises:
            ValueError: if ``inDims`` does not have one entry per categorical column.
        """
        super().__init__()
        self.categoricals = 'Author Genre BookCategory Edition'.split()
        if len(inDims) != len(self.categoricals):
            raise ValueError(
                f'expected {len(self.categoricals)} vocabulary sizes, got {len(inDims)}'
            )
        # Embedding width grows slowly with cardinality: floor(n ** 0.25) + 1.
        outDims = lambda x: int(x ** .25) + 1

        embeddingWidth = 0
        for key, dim in zip(self.categoricals, inDims):
            out_dim = outDims(dim)
            # setattr on an nn.Module registers the embedding as a named submodule.
            setattr(self, key, nn.Embedding(dim, out_dim))
            embeddingWidth += out_dim

        numNumericals = 2  # Reviews, Ratings
        # Input width is derived from the embeddings instead of hard-coding it
        # (15 for the default inDims), so changing inDims cannot desync the MLP.
        self.Sequential = nn.Sequential(
            nn.Linear(embeddingWidth + numNumericals, 64), nn.LayerNorm(64), nn.Tanh(),
            nn.Linear(64, 32), nn.LayerNorm(32), nn.Tanh(),
            nn.Linear(32, 2)
        )

    def forward(self, batch: dict) -> tuple:
        """Compute LogNormal parameters for a batch.

        Args:
            batch: dict holding an integer-code tensor per categorical column and
                1-D tensors ``'Reviews'`` / ``'Ratings'`` (one value per row —
                assumed from how they are stacked; confirm against BookDataset).

        Returns:
            ``(mu, sigma)``, each of shape ``(batch,)``; ``sigma`` is strictly positive.
        """
        embedded = torch.cat([
            getattr(self, key)(batch[key])
            for key in self.categoricals
        ], dim=-1)

        # (batch, 2); cast to the embedding dtype so the concat feeds the Linear layer cleanly.
        numericals = torch.stack(
            [batch['Reviews'], batch['Ratings']], dim=-1
        ).to(embedded.dtype)

        inputs = torch.cat([embedded, numericals], dim=-1)

        logits = self.Sequential(inputs)
        mu = logits[:, 0]
        # softplus maps the raw output to a positive scale parameter.
        sigma = F.softplus(logits[:, 1])
        return mu, sigma

    def predict(self, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Return the mean of LogNormal(mu, sigma), i.e. exp(mu + sigma**2 / 2)."""
        return torch.exp(mu + .5 * sigma ** 2)

    def training_step(self, batch: dict, batchIdx: int) -> torch.Tensor:
        """Negative LogNormal log-likelihood of ``batch['Price']``, averaged over the batch."""
        mu, sigma = self(batch)

        # Floor sigma so log_prob stays finite if softplus underflows to 0.
        safeSigma = sigma.clamp_min(1e-7)
        loss = -LogNormal(mu, safeSigma).log_prob(batch['Price']).mean()
        self.log('loss', loss)
        return loss

    def validation_step(self, batch: dict, batchIdx: int) -> torch.Tensor:
        """Log and return the RMSE between predicted means and ``batch['Price']``.

        Lightning averages values logged here across batches, so the epoch-level
        ``val_rmse`` is a mean of per-batch RMSEs (an approximation of the
        dataset-wide RMSE).
        """
        mu, sigma = self(batch)

        preds = self.predict(mu, sigma)
        # sqrt of the mean squared error — the original logged the MSE under the RMSE name.
        rmse = (preds - batch['Price']).pow(2).mean().sqrt()
        self.log('val_rmse', rmse)
        return rmse

    def test_step(self, batch: dict, batchIdx: int) -> torch.Tensor:
        """Return point predictions (LogNormal means) for a test batch."""
        mu, sigma = self(batch)
        preds = self.predict(mu, sigma)
        return preds

    def configure_optimizers(self):
        """Adam at lr=1e-3 with per-epoch exponential decay (factor 0.99)."""
        optimizer = Adam(self.parameters(), lr=1e-3)
        # lr = 1e-3 * 0.99 ** epoch. The previous `0.99 * epoch` gave lr = 0 for the
        # whole first epoch and then increased the lr linearly (~48x by epoch 49).
        scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.99 ** epoch)
        return [optimizer], [scheduler]

# Seed Python / NumPy / torch RNGs so weight initialisation and shuffling are reproducible.
SEED = 42
pl.seed_everything(SEED)

model = BookPrice()
# Fall back to CPU when no GPU is present instead of failing at Trainer construction.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=50)

# Only the training data is shuffled; val/test keep dataset order so the test
# predictions written to the submission file stay aligned with the test rows.
BATCH_SIZE = 32
trainLoader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
valLoader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=False)
testLoader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

trainer.fit(model, trainLoader, valLoader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type       | Params
--------------------------------------------
0 | Author       | Embedding  | 2 K   
1 | Genre        | Embedding  | 804   
2 | BookCategory | Embedding  | 24    
3 | Edition      | Embedding  | 4     
4 | Sequential   | Sequential | 3 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [6]:
# Run Lightning's test loop over the test loader. This cell does not keep the
# per-batch predictions returned by BookPrice.test_step; the next cell recomputes
# them manually to build the submission file.
trainer.test(model, testLoader)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------



1

In [7]:
# Gather point predictions (LogNormal means) for every test batch, in loader order,
# as a single NumPy array.
model.eval()
batchPredictions = []
with torch.no_grad():
    for batch in testLoader:
        # Move every tensor in the batch onto the model's device.
        onDevice = {key: tensor.to(model.device) for key, tensor in batch.items()}
        batchPredictions.append(model.test_step(onDevice, 0))
results = torch.cat(batchPredictions).cpu().numpy()

In [8]:
# One 'Price' column, one row per test example, written without the index column.
submission = pd.DataFrame({'Price': results})
submission.to_excel('mysubmission_tanh.xlsx', index=False)