## Definitions

In [None]:
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import pandas as pd

class TextDataset(Dataset):
    """Map-style dataset that tokenizes the ``text`` column of a DataFrame.

    Each item is a dict of 1-D tensors (``input_ids``, ``attention_mask`` and
    any other fields the tokenizer returns), padded/truncated to
    ``max_length`` tokens. No labels are returned: the dataset is used for
    inference only.
    """

    def __init__(self, df, max_length=128, pretrained_name='bert-base-uncased'):
        """
        Args:
            df: DataFrame with a ``text`` column; one dataset item per row.
            max_length: Length every encoded sequence is padded/truncated to.
            pretrained_name: Name or path given to ``BertTokenizer.from_pretrained``.
        """
        self.text_list = df['text'].tolist()
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_name)
        self.max_length = max_length

    def __len__(self):
        return len(self.text_list)

    def __getitem__(self, idx):
        """Tokenize the text at ``idx`` and return its encoding.

        Returns a dict mapping tokenizer output names to tensors of shape
        ``(max_length,)``. ``return_tensors='pt'`` adds a leading batch
        dimension of size 1, which is squeezed away so the DataLoader can
        stack items into batches itself.
        """
        encoded = self.tokenizer(
            self.text_list[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        return {name: tensor.squeeze(0) for name, tensor in encoded.items()}
        
        

In [None]:
import torch
from torch import nn
from transformers import BertModel


class MyModel(nn.Module):
    """BERT encoder with a linear classification head on the first ([CLS]) token."""

    def __init__(self, out_dim, pretrained_name='bert-base-uncased'):
        """
        Args:
            out_dim: Number of output classes.
            pretrained_name: Name or path given to ``BertModel.from_pretrained``.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_name)
        # Size the head from the encoder config rather than hard-coding 768 so
        # checkpoints with a different hidden size (e.g. bert-large) also work.
        # Attribute names are kept so existing state_dict keys still match.
        self.layer1 = nn.Linear(self.bert.config.hidden_size, out_dim)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits of shape ``(batch, out_dim)``."""
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state of token 0 ([CLS]) is used as the sentence embedding:
        # shape (batch, hidden_size).
        cls_embedding = output['last_hidden_state'][:, 0, :]
        return self.layer1(cls_embedding)

In [None]:
import numpy as np
import pandas as pd


@torch.no_grad()
def infer(model, loader, threshold=None, device=None):
    """Run ``model`` over ``loader`` and return predicted classes and confidences.

    Args:
        model: ``nn.Module`` called as ``model(input_ids, attention_mask)`` and
            returning logits of shape ``(batch, num_classes)``.
        loader: Iterable of dicts holding at least ``input_ids`` and
            ``attention_mask`` tensors (e.g. a DataLoader over TextDataset).
        threshold: If given, confidences are converted to a boolean mask
            ``confidence >= threshold``.
        device: Device batches are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(preds, confidences)``: ``preds`` is an integer array of argmax class
        ids. ``confidences`` holds the max softmax probability per row, or a
        boolean array when ``threshold`` is given. Both are empty when
        ``loader`` yields nothing.
    """
    if device is None:
        # Follow the model instead of hard-coding 'cuda' so inputs and
        # weights are always on the same device.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Disable dropout for deterministic predictions, then restore the
    # caller's train/eval mode.
    was_training = model.training
    model.eval()
    probs, confidences = [], []
    try:
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            h = torch.softmax(model(batch['input_ids'], batch['attention_mask']), -1)
            probs.append(h.cpu().numpy())
            confidences.append(h.max(-1).values.cpu().numpy())
    finally:
        model.train(was_training)

    if not probs:
        # Empty loader: np.concatenate([]) would raise, so return empty arrays.
        preds = np.empty(0, dtype=np.int64)
        confidences = np.empty(0, dtype=np.float32)
    else:
        preds = np.argmax(np.concatenate(probs, 0), 1)
        confidences = np.concatenate(confidences, 0)

    if threshold is not None:
        confidences = confidences >= threshold

    return preds, confidences


# Topic labels; position i is the label for model output class i.
TOPIC_LIST = (
    "1 The particulate nature of matter",
    "2 Experimental techniques",
    "3 Atoms, elements and compounds",
    "4 Stoichiometry",
    "5 Electricity and chemistry",
    "6 Chemical energetics",
    "7 Chemical reactions",
    "8 Acids, bases and salts",
    "9 The Periodic Table",
    "10 Metals",
    "11 Air and water",
    "12 Sulfur",
    "13 Carbonates",
    "14 Organic chemistry",
)


def topic_inference(model, data_path, threshold=None, topic_list=TOPIC_LIST, batch_size=64):
    """Fill in missing ``topic`` values of a CSV with model predictions.

    Only rows that have a ``text`` but no ``topic`` are classified. All other
    rows are returned unchanged.

    Args:
        model: Classifier passed to ``infer``.
        data_path: Path to a CSV with a ``text`` column and, optionally, a
            ``topic`` column (created as all-missing if absent).
        threshold: Minimum softmax confidence to accept a prediction. Rejected
            rows get ``''`` as topic. ``None`` accepts every prediction.
        topic_list: Labels indexed by predicted class id.
        batch_size: DataLoader batch size.

    Returns:
        The full DataFrame with predicted topics written into ``topic``.
    """
    df = pd.read_csv(data_path)
    if 'topic' not in df.columns:
        df['topic'] = np.nan

    # Rows with text but no topic yet are the ones that need a prediction.
    index = df['text'].notna() & df['topic'].isna()
    if not index.any():
        return df  # nothing to classify; also avoids running on an empty loader

    dataset = TextDataset(df[index])
    dataloader = DataLoader(dataset, batch_size=batch_size)
    preds, confidence_index = infer(model, dataloader, threshold=threshold)

    if threshold is None:
        # Without a threshold every prediction is kept.
        accepted = np.ones(len(preds), dtype=bool)
    else:
        accepted = np.asarray(confidence_index, dtype=bool)

    topic_preds = [topic_list[p] if ok else '' for p, ok in zip(preds, accepted)]

    # An all-missing 'topic' column is read as float64; cast to object so
    # assigning strings does not hit pandas' incompatible-dtype warning.
    df['topic'] = df['topic'].astype(object)
    df.loc[index, 'topic'] = topic_preds

    return df

## Run

In [None]:
# Use the GPU when available. Inference moves input batches to a device, and
# the model's weights must live on that same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MyModel(14)  # 14 outputs: one per topic label used in topic_inference

model_path = '/content/drive/MyDrive/models/best.pth'
data_path = 'all_data.csv'

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

df = topic_inference(model, data_path)

# index=False: otherwise every run appends an extra unnamed index column.
df.to_csv(data_path, index=False)