In [4]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.21.0-py3-none-any.whl (4.7 MB)
[K     |████████████████████████████████| 4.7 MB 12.5 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 38.8 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 62.0 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 5.8 MB/s 
Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstal

In [5]:
# Mount Google Drive at /content/drive; the checkpoint path used in the "Run" cell
# (/content/drive/MyDrive/models/...) is only reachable after this succeeds.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Definitions

In [6]:
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import pandas as pd

class TextDataset(Dataset):
    """Dataset that tokenizes the `text` column of a DataFrame one row at a time.

    Items carry tokenizer outputs only; no label is attached to them.
    """

    def __init__(self, df, max_length=128):
        """Store the texts of `df['text']` and load the `bert-base-uncased` tokenizer.

        Args:
            df: DataFrame providing a `text` column.
            max_length: Length every text is padded / truncated to.
        """
        # Keep the texts as a plain Python list; tokenization is deferred to __getitem__.
        text_values = df['text'].values
        self.text_list = text_values.tolist()
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        """Return the number of stored texts."""
        return len(self.text_list)

    def __getitem__(self, idx):
        """Tokenize the text stored at position `idx`.

        Returns:
            A dict with the keys produced by the tokenizer. `squeeze(0)` is applied
            to every value, which drops the leading axis when it has size 1
            (presumably the batch axis added by `return_tensors='pt'`).
        """
        tokenized = self.tokenizer(
            self.text_list[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        item = {}
        for key, tensor in tokenized.items():
            item[key] = tensor.squeeze(0)
        return item
        
        

In [7]:
import torch
from torch import nn
from transformers import BertModel


class MyModel(nn.Module):
    """Pretrained BERT encoder with one linear layer on top of the first token's hidden state."""

    def __init__(self, out_dim):
        """Load `bert-base-uncased` and create a 768 -> `out_dim` linear head."""
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # The head consumes vectors of width 768, so the encoder's hidden size must match.
        self.layer1 = nn.Linear(768, out_dim)

    def forward(self, input_ids, attention_mask):
        """Encode the batch and project the hidden state at sequence position 0.

        Returns:
            The output of `layer1`, i.e. one `out_dim`-sized row per batch element.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoded['last_hidden_state']
        first_token = hidden_states[:, 0, :]  # [batch size, num hidden dim]
        projected = self.layer1(first_token)
        return projected

In [12]:
import numpy as np
import pandas as pd


def _model_device(model):
    """Return the device of `model`'s first parameter, or a default device if it has none."""
    if isinstance(model, torch.nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            return first_param.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@torch.no_grad()
def infer(model, loader, threshold=None):
    """Run `model` over every batch of `loader` and collect predictions and confidences.

    Args:
        model: Callable invoked as `model(input_ids, attention_mask)`. Its output is
            turned into probabilities with a softmax over the last axis.
        loader: Iterable of dicts of tensors; each dict must provide `input_ids`
            and `attention_mask`.
        threshold: Optional confidence cut-off. When given, the returned confidences
            are booleans (`confidence >= threshold`) instead of raw probabilities.

    Returns:
        A `(preds, confidences)` tuple of NumPy arrays with one entry per sample:
        `preds` is the argmax class index, `confidences` is the highest class
        probability (or its thresholded boolean). Both are empty when `loader`
        yields no batches.
    """
    # Follow the model's own device instead of assuming a GPU is present.
    device = _model_device(model)

    # Inference must not use dropout; remember every submodule's mode so the
    # caller gets the model back exactly as it was handed in.
    modules = list(model.modules()) if isinstance(model, torch.nn.Module) else []
    saved_modes = [module.training for module in modules]
    if modules:
        model.eval()

    probs = []
    confidences = []
    try:
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            h = torch.softmax(model(batch['input_ids'], batch['attention_mask']), -1)
            confidence = h.max(-1)[0]

            probs.append(h.cpu().numpy())
            confidences.append(confidence.cpu().numpy())
    finally:
        for module, mode in zip(modules, saved_modes):
            module.training = mode

    if not probs:
        # np.concatenate raises on an empty list, so answer "nothing to predict" directly.
        conf_dtype = bool if threshold is not None else np.float32
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=conf_dtype)

    probs = np.concatenate(probs, 0)
    confidences = np.concatenate(confidences, 0)

    preds = np.argmax(probs, 1)

    if threshold is not None:
        confidences = confidences >= threshold

    return preds, confidences


def topic_inference(model, data_path, topic_list, threshold = None, batch_size = 64):
    """Fill in missing topics of a CSV file with the model's predictions.

    Rows are selected for inference when `text` is present and, if the file
    already has a `topic` column, `topic` is missing. All other rows are left
    untouched.

    Args:
        model: Classifier handed to `infer`.
        data_path: Path of the CSV file to read; it must contain a `text` column.
        topic_list: Topic labels; entry `i` is the label written for predicted class index `i`.
        threshold: Optional confidence cut-off forwarded to `infer`. Predictions
            that `infer` does not mark as confident are written as the empty string.
        batch_size: Batch size of the inference DataLoader.

    Returns:
        The full DataFrame read from `data_path`, with a `topic` column in which
        the selected rows hold the predicted label (or `''`).
    """
    df = pd.read_csv(data_path)

    needs_topic = ~df['text'].isna()
    if 'topic' in df:
        needs_topic = needs_topic & df['topic'].isna()
        # An all-NaN column is read as float; make it object up front so that
        # writing string labels into it does not rely on implicit upcasting.
        df['topic'] = df['topic'].astype(object)
    else:
        df['topic'] = pd.Series(np.nan, index=df.index, dtype=object)
    infer_df = df[needs_topic]

    count = 0
    # Nothing to label: skip model inference entirely instead of feeding an empty loader.
    if len(infer_df) > 0:
        dataset = TextDataset(infer_df)
        dataloader = DataLoader(dataset, batch_size=batch_size)

        preds, confidence_index = infer(model, dataloader, threshold = threshold)

        # With threshold=None `infer` returns raw probabilities; like the truthiness
        # test on each element, any non-zero value counts as accepted.
        accepted = np.asarray(confidence_index).astype(bool)
        topic_preds = [topic_list[p] if ok else '' for p, ok in zip(preds, accepted)]
        count = int(accepted.sum())

        df.loc[needs_topic, 'topic'] = topic_preds

    print(f"inferred {count} topics using threshold {threshold}")

    return df

## Run

In [16]:
# Topic labels; position i is the label written for predicted class index i,
# so the order must match the class indices of the loaded checkpoint.
topic_list = [
    "1 The particulate nature of matter",
    "2 Experimental techniques",
    "3 Atoms, elements and compounds",
    "4 Stoichiometry",
    "5 Electricity and chemistry",
    "6 Chemical energetics",
    "7 Chemical reactions",
    "8 Acids, bases and salts",
    "9 The Periodic Table",
    "10 Metals",
    "11 Air and water",
    "12 Sulfur",
    "13 Carbonates",
    "14 Organic chemistry",
]
subject = 'chemistry'
# One output unit per topic; derived from the list so the two cannot drift apart.
out_dim = len(topic_list)

model = MyModel(out_dim)
model = model.to('cuda')

model_path = f'/content/drive/MyDrive/models/{subject}_best.pth'
data_path = 'all_data.csv'

# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
# map_location keeps loading independent of the device the checkpoint was saved from.
model.load_state_dict(torch.load(model_path, map_location='cuda'))
# Inference only: make sure dropout and similar train-time behavior is switched off.
model.eval()

df = topic_inference(model, data_path, topic_list, threshold=0.8)

# NOTE: the DataFrame index is written as an unnamed first column (pandas default).
df.to_csv('all_data_labeled.csv')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


inferred 347 topics using threshold 0.8
