In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

from torch.utils.data import DataLoader
from BookDataset import getBookDataset, BookDataset
from embeddings import EmbeddingSummation

In [2]:
train, val, test, featureSizes = getBookDataset('../Data_Train.xlsx', '../Data_Test.xlsx')

In [3]:
class BookPrice(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        outDims = lambda x: int(x**.25) + 1
        self.outDims = outDims
        
        # embeddings for categorical variables: Author, Genre, BookCategory
        self.AuthorEmbedding = nn.Embedding(config.Author,  outDims(config.Author))
        self.GenreEmbedding  = nn.Embedding(config.Genre,  outDims(config.Genre))
        self.BookCategoryEmbedding  = nn.Embedding(config.BookCategory,  outDims(config.BookCategory))
        categoricalUnits = 5 * outDims(config.Author) + outDims(config.Genre) + outDims(config.BookCategory)
        # ------
        
        
        # embeddings for text features Title, Synopsis
        self.TitleEmbedding    = EmbeddingSummation()
        self.SynopsisEmbedding = EmbeddingSummation()
        textUnits = 128 * 2
        #-------
        
        # Ratings, Reviews and Edition
        numericUnits = 3
        #-------

        total = categoricalUnits + textUnits + numericUnits
        
        self.Dense = nn.Sequential(
            nn.Linear(total, 64), nn.LayerNorm(64), nn.Tanh(), nn.Dropout(.2),
            nn.Linear(64, 32), nn.LayerNorm(32), nn.Tanh(), nn.Dropout(.2),
            nn.Linear(32, 1)
        )

    def forward(self, batch:dict) -> torch.Tensor:
        authors = torch.stack([ batch[f'author_{i}'] for i in range(5) ]).T
        authDim = 5 * self.outDims(self.config.Author)
        categoricals = [
            self.AuthorEmbedding(authors).reshape(-1, authDim),
            self.GenreEmbedding(batch['Genre']),
            self.BookCategoryEmbedding(batch['BookCategory'])
        ]
        
        text = [
            self.TitleEmbedding(batch['Title']),
            self.SynopsisEmbedding(batch['Synopsis'])
        ]
        
        numericals  = [torch.stack([
            batch['Ratings'],
            batch['Reviews'],
            batch['Edition']
        ]).T]
        
        inputs = torch.cat(categoricals + text + numericals, dim=-1)
        
        logits = self.Dense(inputs)
        return logits
    
    
    def training_step(self, batch:dict, batchIdx:int)->torch.Tensor:
        logits = self(batch)

        loss = F.mse_loss(logits, batch['Price'].reshape(-1, 1))
        self.log('loss', loss)

        return loss

    def validation_step(self, batch:dict, batchIdx:int):
        logits = self(batch)
        loss = F.mse_loss(logits, batch['Price'].reshape(-1, 1))
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.99 * epoch)
        return [optimizer], [scheduler]

    
model = BookPrice(featureSizes)
x = sum(p.numel() for p in model.parameters())
print(f'BookPrice model has {x:,} parameters')

BookPrice model has 7,836,458 parameters


In [4]:
stopping = EarlyStopping(monitor='val_loss', patience=10, mode='min')
ckpt = ModelCheckpoint(dirpath='checkpoints', filename='{epoch}-{val_loss:.5f}',
                       monitor='val_loss', mode='min',
                      save_weights_only=True, verbose=True)

trainer = pl.Trainer(gpus = 1, callbacks=[stopping, ckpt])

loaders = []
for i, data in enumerate([train, val, test]):
    shuffle = i==0
    loaders.append( DataLoader(BookDataset(data), batch_size=16, shuffle=shuffle ) )

trainer.fit(model, loaders[0], loaders[1])

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                  | Type               | Params
-------------------------------------------------------------
0 | AuthorEmbedding       | Embedding          | 2.6 K 
1 | GenreEmbedding        | Embedding          | 804   
2 | BookCategoryEmbedding | Embedding          | 24    
3 | TitleEmbedding        | EmbeddingSummation | 3.9 M 
4 | SynopsisEmbedding     | EmbeddingSummation | 3.9 M 
5 | Dense                 | Sequential         | 20.9 K
-------------------------------------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 0, global step 311: val_loss reached 45.44375 (best 45.44375), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=0-val_loss=45.44375.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 1, global step 623: val_loss reached 0.61094 (best 0.61094), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=1-val_loss=0.61094.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 2, global step 935: val_loss reached 0.60454 (best 0.60454), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=2-val_loss=0.60454.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 3, global step 1247: val_loss reached 0.57148 (best 0.57148), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=3-val_loss=0.57148.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 4, global step 1559: val_loss reached 0.45826 (best 0.45826), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=4-val_loss=0.45826.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 5, global step 1871: val_loss reached 0.35972 (best 0.35972), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=5-val_loss=0.35972.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 6, step 2183: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 7, global step 2495: val_loss reached 0.32916 (best 0.32916), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=7-val_loss=0.32916.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 8, step 2807: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 9, step 3119: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 10, global step 3431: val_loss reached 0.32819 (best 0.32819), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=10-val_loss=0.32819.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 11, step 3743: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 12, step 4055: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 13, step 4367: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 14, step 4679: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 15, step 4991: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 16, step 5303: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 17, step 5615: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 18, global step 5927: val_loss reached 0.32766 (best 0.32766), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=18-val_loss=0.32766.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 19, step 6239: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 20, step 6551: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 21, global step 6863: val_loss reached 0.32656 (best 0.32656), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=21-val_loss=0.32656.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 22, step 7175: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 23, global step 7487: val_loss reached 0.32536 (best 0.32536), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=23-val_loss=0.32536.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 24, global step 7799: val_loss reached 0.31857 (best 0.31857), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=24-val_loss=0.31857.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 25, step 8111: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 26, step 8423: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 27, step 8735: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 28, step 9047: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 29, step 9359: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 30, step 9671: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 31, global step 9983: val_loss reached 0.31622 (best 0.31622), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=31-val_loss=0.31622.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 32, step 10295: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 33, step 10607: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 34, global step 10919: val_loss reached 0.31449 (best 0.31449), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=34-val_loss=0.31449.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 35, step 11231: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 36, step 11543: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 37, step 11855: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 38, step 12167: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 39, step 12479: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 40, step 12791: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 41, global step 13103: val_loss reached 0.30841 (best 0.30841), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=41-val_loss=0.30841.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 42, global step 13415: val_loss reached 0.30309 (best 0.30309), saving model to "C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=42-val_loss=0.30309.ckpt" as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 43, step 13727: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 44, step 14039: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 45, step 14351: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 46, step 14663: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 47, step 14975: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 48, step 15287: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 49, step 15599: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 50, step 15911: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 51, step 16223: val_loss was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch 52, step 16535: val_loss was not in top 1





1

In [5]:
print("Saving model from best ckpt", ckpt.best_model_path)
model = BookPrice.load_from_checkpoint(ckpt.best_model_path, config=featureSizes).eval().cuda()
results = []
with torch.no_grad():
    for batch in loaders[2]:
        for key in batch:
            if key not in 'Title Synopsis'.split():
                batch[key] = batch[key].to(model.device)
        preds = model(batch)
        results.append(preds)
results = torch.cat(results).cpu().numpy().reshape(-1)
submission = pd.DataFrame({
    'Price': np.exp(results) - 1
})
print("Saving submission ", submission.shape)
submission.to_excel('submission.xlsx', index=False)

Saving model from best ckpt C:\Users\Deepak H R\Desktop\data\BookPrice\albert\chekpoints\epoch=42-val_loss=0.30309.ckpt
Saving submission  (1560, 1)
