In [1]:
import tqdm
import pandas as pd
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from prettytable import PrettyTable

### Load Data

In [2]:
def load_data(split_name='train', columns=['text', 'label'], folder='data'):
    """Load one split from ``{folder}/{split_name}.csv``.

    Args:
        split_name: File stem of the CSV to read (e.g. ``'train'``, ``'valid'``).
        columns: Names of the columns to keep. The default list is never
            mutated; it is copied before use.
        folder: Directory that contains the CSV file.

    Returns:
        A DataFrame restricted to ``columns`` (in that order) when every
        requested column exists; otherwise the full DataFrame with all columns.

    Raises:
        Whatever ``pd.read_csv`` raises (e.g. ``FileNotFoundError``) when the
        file cannot be read. Only the "requested column is missing" case is
        treated as recoverable.
    """
    wanted = list(columns)
    print(f"select [{', '.join(map(str, wanted))}] columns from the {split_name} split")

    # Read the file exactly once; the fallback reuses this frame instead of
    # parsing the CSV a second time.
    df = pd.read_csv(f'{folder}/{split_name}.csv')

    # Explicit membership check instead of a bare ``except`` so that unrelated
    # failures (I/O errors, KeyboardInterrupt, ...) are not silently swallowed.
    missing = [name for name in wanted if name not in df.columns]
    if missing:
        print(f"Failed loading specified columns... Returning all columns from the {split_name} split")
        return df

    print("Success")
    return df.loc[:, wanted]

# Labelled validation split: the id is kept so predictions can be matched back to rows.
valid_df = load_data(split_name='valid', columns=['id', 'text', 'label'], folder='data')
# Unlabelled test split: only the id and the raw text are requested.
test_df = load_data(split_name='test_no_label', columns=['id', 'text'], folder='data')

select [id, text, label] columns from the valid split
Success
select [id, text] columns from the test_no_label split
Success


### Prepare and Load Model from checkpoint

In [3]:
class SAM(nn.Module):
    """BERT encoder (``bert-base-uncased``) followed by a linear classification head.

    The head is applied to the encoder's ``pooler_output``. ``predict`` maps the
    0-based arg-max class index to a 1-based label (``index + 1``).
    """

    def __init__(self, class_num, device):
        """Build the tokenizer, the pretrained encoder and the linear head.

        Args:
            class_num: Number of output classes of the linear head.
            device: Device that input tensors are moved to before the encoder
                runs. The module itself is NOT moved here; the caller is
                expected to call ``.to(device)``.
        """
        super().__init__()
        self.class_num = class_num
        # Token budget used by predict() for both padding and truncation.
        self.max_len = 256
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.pretrained_model = BertModel.from_pretrained("bert-base-uncased", return_dict=True)
        # Size the head from the encoder config instead of hard-coding the
        # hidden width (768 for bert-base-uncased).
        self.linear = nn.Linear(self.pretrained_model.config.hidden_size, self.class_num, bias=True)
        self.loss_fnc = nn.CrossEntropyLoss()
        self.device = device

    def load_checkpoint(self, checkpoint_path=None):
        """Restore weights from ``checkpoint_path``; do nothing when it is falsy.

        The checkpoint is expected to be a mapping whose ``'model'`` entry is a
        state dict for this module.

        NOTE: ``torch.load`` unpickles the file, so only load checkpoints that
        come from a trusted source.
        """
        if checkpoint_path:
            # map_location lets a checkpoint saved on a GPU be restored on a
            # CPU-only machine (and vice versa).
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.load_state_dict(checkpoint['model'])
            print('model checkpoint loaded')

    def _pooled_output(self, input_ids, attention_mask, token_type_ids):
        """Move the three input tensors to ``self.device`` and return the encoder's pooler_output."""
        encoded = self.pretrained_model(
            input_ids=input_ids.to(self.device, dtype=torch.long),
            attention_mask=attention_mask.to(self.device, dtype=torch.long),
            token_type_ids=token_type_ids.to(self.device, dtype=torch.long),
        )
        return encoded.pooler_output

    def forward(self, samples):
        """Compute logits (and the loss, when labels are supplied) for a batch.

        Args:
            samples: Mapping with ``'input_ids'``, ``'attention_mask'`` and
                ``'token_type_ids'`` tensors, plus an optional ``'labels'`` tensor.

        Returns:
            ``{'logits': Tensor, 'loss': Tensor or None}``. ``loss`` is ``None``
            when ``samples`` has no ``'labels'`` entry.
        """
        logits = self.linear(self._pooled_output(samples['input_ids'],
                                                 samples['attention_mask'],
                                                 samples['token_type_ids']))
        loss = None
        if 'labels' in samples:
            labels = samples['labels']
            if labels.dim() == logits.dim():
                # Same rank as the logits: treated as per-class probabilities /
                # one-hot rows, which CrossEntropyLoss requires to be floating point.
                target = labels.to(self.device, dtype=torch.float)
            else:
                # Class-index targets must be integer typed for CrossEntropyLoss.
                # NOTE(review): indices are presumably 0-based here, since
                # predict() adds 1 to the arg-max; confirm against the training code.
                target = labels.to(self.device, dtype=torch.long)
            loss = self.loss_fnc(logits, target)
        return {'logits': logits, 'loss': loss}

    def predict(self, text):
        """Predict the label of a single text.

        The text is tokenized, then padded/truncated to ``self.max_len`` tokens.

        NOTE(review): this method does not switch the module to eval mode; call
        ``model.eval()`` first if the module may be in training mode.

        Returns:
            ``{'label': Tensor}``, a 0-d integer tensor holding the 1-based label
            (arg-max class index + 1).
        """
        inputs = self.tokenizer.encode_plus(text,
                                            add_special_tokens=True,
                                            max_length=self.max_len,
                                            padding='max_length',
                                            truncation=True,
                                            return_attention_mask=True,
                                            return_tensors='pt')
        # Both the encoder and the head run without autograd bookkeeping.
        with torch.no_grad():
            pooled = self._pooled_output(inputs['input_ids'].flatten().unsqueeze(0),
                                         inputs['attention_mask'].flatten().unsqueeze(0),
                                         inputs['token_type_ids'].flatten().unsqueeze(0))
            logits = self.linear(pooled)
        # softmax is monotonic, so the arg-max of the raw logits is the same class.
        label = logits[0].argmax() + 1
        return {'label': label}
    
def count_trainable_parameters(model):
    """Print a per-parameter table of trainable sizes and return their sum.

    Parameters whose ``requires_grad`` flag is off are left out of both the
    table and the total.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. an ``nn.Module``).

    Returns:
        Total number of elements across all trainable parameters.
    """
    table = PrettyTable(["Modules", "Parameters"])
    # Gather (name, element count) pairs for trainable parameters first ...
    trainable = [(name, weight.numel())
                 for name, weight in model.named_parameters()
                 if weight.requires_grad]
    # ... then fill the table in the same order the model reports them.
    for name, size in trainable:
        table.add_row([name, size])
    total_params = sum(size for _, size in trainable)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [4]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# nn.Module.to() returns the module itself, so construction and placement can be chained.
model = SAM(class_num=5, device=device).to(device)
# Restore the fine-tuned weights on top of the pretrained encoder and fresh head.
model.load_checkpoint(checkpoint_path="finetune_bert_linear_checkpoint.pth")

model checkpoint loaded


### Validation Set Predictions

In [5]:
# Predict a label for every validation row, one text at a time.
val_data = []
# zip over the two columns that are actually used; `total` gives tqdm a real
# progress bar instead of a bare iteration counter.
for sample_id, text in tqdm.tqdm(zip(valid_df['id'], valid_df['text']), total=len(valid_df)):
    prediction = model.predict(text=text)
    # .item() stores a plain Python int rather than a 0-d NumPy array, so the
    # resulting DataFrame column gets an integer dtype.
    val_data.append({'id': sample_id, 'label': prediction['label'].item()})

2000it [01:16, 26.30it/s]


In [6]:
# Turn the list of {'id', 'label'} records into a table and preview its first rows.
valid_pred = pd.DataFrame(data=val_data)
valid_pred.head(n=5)

Unnamed: 0,id,label
0,A29WNXUH97IH13_6848,3
1,A1FV0HOXQA87O8_10874,4
2,A3XZ7FSPXP9S4_745,4
3,A2W2O4WH9VZCXQ_15439,5
4,A24Y7A0B20RWT9_23351,4


In [7]:
# Write the validation predictions to disk without the positional index column.
valid_pred.to_csv(path_or_buf='data/valid_pred.csv', index=False)

### Test Set Predictions

In [8]:
# Predict a label for every test row, one text at a time.
test_data = []
# zip over the two columns that are actually used; `total` gives tqdm a real
# progress bar instead of a bare iteration counter.
for sample_id, text in tqdm.tqdm(zip(test_df['id'], test_df['text']), total=len(test_df)):
    prediction = model.predict(text=text)
    # .item() stores a plain Python int rather than a 0-d NumPy array, so the
    # resulting DataFrame column gets an integer dtype.
    test_data.append({'id': sample_id, 'label': prediction['label'].item()})

4000it [02:28, 26.90it/s]


In [9]:
# Turn the list of {'id', 'label'} records into a table and preview its first rows.
test_pred = pd.DataFrame(data=test_data)
test_pred.head(n=5)

Unnamed: 0,id,label
0,A3EMGD8RAEOK64_2907,3
1,A2BOWU2PX28BET_5501,5
2,A100WO06OQR8BQ_10469,1
3,A2H4LKU7CPIUU9_11364,1
4,A14RF11JYGDKI8_23751,3


In [10]:
# Write the test-set predictions to disk without the positional index column.
test_pred.to_csv(path_or_buf='data/pred.csv', index=False)