In [1]:
import pandas as pd
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, Sampler
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModel

In [2]:
# Load the training split. The embedding cell further down reads the Title and
# Synopsis columns, so fail here with a clear message if either is absent.
train = pd.read_excel('../Data_Train.xlsx')
missing_columns = {'Title', 'Synopsis'} - set(train.columns)
assert not missing_columns, f"Data_Train.xlsx is missing columns: {sorted(missing_columns)}"
train.head()

Unnamed: 0,Title,Author,Edition,Reviews,Ratings,Synopsis,Genre,BookCategory,Price
0,The Prisoner's Gold (The Hunters 3),Chris Kuzneski,"Paperback,– 10 Mar 2016",4.0 out of 5 stars,8 customer reviews,THE HUNTERS return in their third brilliant no...,Action & Adventure (Books),Action & Adventure,220.0
1,Guru Dutt: A Tragedy in Three Acts,Arun Khopkar,"Paperback,– 7 Nov 2012",3.9 out of 5 stars,14 customer reviews,A layered portrait of a troubled genius for wh...,Cinema & Broadcast (Books),"Biographies, Diaries & True Accounts",202.93
2,Leviathan (Penguin Classics),Thomas Hobbes,"Paperback,– 25 Feb 1982",4.8 out of 5 stars,6 customer reviews,"""During the time men live without a common Pow...",International Relations,Humour,299.0
3,A Pocket Full of Rye (Miss Marple),Agatha Christie,"Paperback,– 5 Oct 2017",4.1 out of 5 stars,13 customer reviews,A handful of grain is found in the pocket of a...,Contemporary Fiction (Books),"Crime, Thriller & Mystery",180.0
4,LIFE 70 Years of Extraordinary Photography,Editors of Life,"Hardcover,– 10 Oct 2006",5.0 out of 5 stars,1 customer review,"For seven decades, ""Life"" has been thrilling t...",Photography Textbooks,"Arts, Film & Photography",965.62


In [3]:
class Bert(nn.Module):
    """Sentence encoder wrapping a pretrained tokenizer/model pair.

    ``forward`` tokenizes a batch of raw strings and returns one vector per
    string, taken either from the embedding layer only (``EMBED_FORWARD``) or
    from the full model's pooler output (``MODEL_FORWARD``).
    """

    ALBERT_END_IDX = 3    # token id written at the cut position of truncated rows
    ALBERT_PAD_IDX = 0    # token id treated as padding when looking for truncated rows
    ALBERT_MAX_LEN = 512  # longest token sequence handed to the model
    EMBED_FORWARD = 0     # forwardType: masked sum of embedding-layer outputs
    MODEL_FORWARD = 1     # forwardType: pooler output of the full model

    def __init__(self, modelName="albert-base-v2"):
        """Load the tokenizer and the model registered under ``modelName``."""
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(modelName)
        self.model = AutoModel.from_pretrained(modelName) #11683584 params

    def setInputsToDevice(self, inputs):
        """Return a plain dict with every tensor of ``inputs`` moved to the model's device."""
        device = self.model.device
        return {key: value.to(device) for key, value in inputs.items()}

    def truncateInputs(self, inputs):
        """Cut every tensor in ``inputs`` to ``ALBERT_MAX_LEN`` columns.

        Rows that still hold a non-padding token at the last kept position were
        cut short, so that position is overwritten with ``ALBERT_END_IDX``.
        Batches that already fit are returned untouched. The mapping is updated
        in place and returned; the tensors originally stored in it are never
        written to.
        """
        if inputs['input_ids'].size(1) <= self.ALBERT_MAX_LEN:
            return inputs

        # truncate inputs at dim=1
        for key in list(inputs):
            inputs[key] = inputs[key][:, :self.ALBERT_MAX_LEN]

        # The slice above is a view of the caller's tensor: clone before writing
        # so the end token does not leak back into the original input_ids.
        last = self.ALBERT_MAX_LEN - 1
        input_ids = inputs['input_ids'].clone()
        truncated_rows = input_ids[:, last] != self.ALBERT_PAD_IDX
        input_ids[truncated_rows, last] = self.ALBERT_END_IDX
        inputs['input_ids'] = input_ids
        return inputs

    def embeddingForward(self, inputs:dict):
        """Return, per sequence, the sum of embedding-layer outputs over attended positions.

        ``inputs`` must provide ``input_ids``, ``token_type_ids`` and
        ``attention_mask``. The result has one row per sequence (b, u).
        """
        # forward pass on embeddings, out is b t u
        out = self.model.embeddings(input_ids=inputs['input_ids'], token_type_ids=inputs['token_type_ids'])

        # (b, 1, t) mask @ (b, t, u) embeddings sums the positions whose mask is 1.
        # The mask takes the embeddings' dtype so the matmul also works when the
        # model does not run in float32.
        mask = torch.unsqueeze(inputs['attention_mask'], dim=1).to(out.dtype) # b 1 t
        out = torch.matmul(mask, out) # b 1 u
        return torch.squeeze(out, dim=1) # b u

    def forward(self, sentences:list, forwardType)->torch.Tensor:
        """Encode ``sentences`` into one vector each.

        Args:
            sentences: iterable of strings (a list, or the tuple/list a
                DataLoader batch yields).
            forwardType: ``Bert.EMBED_FORWARD`` or ``Bert.MODEL_FORWARD``.

        Raises:
            ValueError: if ``forwardType`` is neither of the two constants.
        """
        # Reject a bad mode before paying for tokenization and device transfer.
        if forwardType not in (self.EMBED_FORWARD, self.MODEL_FORWARD):
            raise ValueError(f'Expected Bert.EMBED_FORWARD or Bert.MODEL_FORWARD, but got {forwardType}')

        # convert sentences to a dict of input vectors; the tokenizer already
        # truncates to ALBERT_MAX_LEN, so truncateInputs is not needed here
        inputs = self.tokenizer(list(sentences), return_tensors='pt',
                                padding=True, max_length=self.ALBERT_MAX_LEN, truncation=True)
        inputs = self.setInputsToDevice(inputs)

        if forwardType == self.EMBED_FORWARD:
            return self.embeddingForward(inputs)
        return self.model(**inputs).pooler_output

# declare model here
# eval() puts the module in evaluation mode for inference. Use the GPU when one
# is visible; otherwise stay on the CPU instead of crashing inside .cuda().
albert = Bert().eval().to('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Each dataset item is a (synopsis, title) pair, so every batch from the loader
# is a pair (batch of synopses, batch of titles) in that same order.
dataset = list(zip(train.Synopsis, train.Title))
loader = DataLoader(dataset, batch_size=64)
outs = []
with torch.no_grad():
    # Unpack in the order the pairs were zipped. The previous loop named these
    # the other way round, so `title` actually held synopses and vice versa.
    # The resulting feature layout is unchanged: the first half of each row is
    # the synopsis embedding, the second half is the title embedding.
    for synopsis, title in tqdm(loader):
        out = torch.cat([
            albert(synopsis, albert.EMBED_FORWARD),
            albert(title, albert.EMBED_FORWARD),
        ], dim=-1)
        # keep results on the CPU so GPU memory does not grow with the dataset
        outs.append(out.cpu())
        torch.cuda.empty_cache()
outs = torch.cat(outs)
print("Output obtained successfully, shape =", outs.shape)

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


Output obtained successfully, shape = torch.Size([6237, 256])
