In [16]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from datasets import Dataset
from transformers import DistilBertModel
from torch.nn.functional import mse_loss


In [15]:
# Pick the compute device once, up front: CUDA when PyTorch can see a GPU,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [2]:
# Read the saved dev split back from disk ...
dev = Dataset.load_from_disk("data/dev")

# ... and materialise it as a pandas DataFrame for inspection.
dev_df = pd.DataFrame(data=dev)

In [6]:
# Show the dev DataFrame as this cell's output (a bare last expression is displayed).
dev_df

Unnamed: 0,par_id,art_id,keyword,country_code,text,labels,pcl,label_category_vector,input_ids,attention_mask
0,107,@@16900972,homeless,ke,"His present "" chambers "" may be quite humble ,...",3.0,1,"[0, 0, 0, 0, 0, 0, 1]","[101, 2010, 2556, 1000, 8477, 1000, 2089, 2022...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1,149,@@1387882,disabled,us,Krueger recently harnessed that creativity to ...,2.0,1,"[1, 0, 0, 0, 0, 0, 1]","[101, 1047, 6820, 26320, 3728, 17445, 2098, 20...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,151,@@19974860,poor-families,in,10:41am - Parents of children who died must ge...,3.0,1,"[1, 0, 0, 1, 0, 0, 0]","[101, 2184, 1024, 4601, 3286, 1011, 3008, 1997...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3,154,@@20663936,disabled,ng,When some people feel causing problem for some...,4.0,1,"[0, 0, 1, 1, 1, 1, 0]","[101, 2043, 2070, 2111, 2514, 4786, 3291, 2005...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,157,@@21712008,poor-families,ca,We are alarmed to learn of your recently circu...,4.0,1,"[1, 1, 0, 0, 1, 1, 0]","[101, 2057, 2024, 19260, 2000, 4553, 1997, 211...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...,...,...,...,...,...,...,...
2088,10463,@@4676355,refugee,pk,""" The Pakistani police came to our house and t...",0.0,0,"[0, 0, 0, 0, 0, 0, 0]","[101, 1000, 1996, 9889, 2610, 2234, 2000, 2256...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2089,10464,@@19612634,disabled,ie,When Marie O'Donoghue went looking for a speci...,0.0,0,"[0, 0, 0, 0, 0, 0, 0]","[101, 2043, 5032, 1051, 1005, 2123, 8649, 2016...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2090,10465,@@14297363,women,lk,Sri Lankan norms and culture inhibit women fro...,1.0,0,"[0, 0, 0, 0, 0, 0, 0]","[101, 5185, 16159, 17606, 1998, 3226, 26402, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2091,10466,@@70091353,vulnerable,ph,He added that the AFP will continue to bank on...,0.0,0,"[0, 0, 0, 0, 0, 0, 0]","[101, 2002, 2794, 2008, 1996, 21358, 2361, 209...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


In [17]:
## Define Model module here ##
class CustomBert(nn.Module):
    """DistilBERT encoder with a mean-pooled feed-forward head producing one scalar per example.

    The hidden states selected by ``transformer_out`` are mean-pooled over the
    attended tokens, concatenated, and passed through a two-layer head.

    Args:
        transformer_out: Either a single int index into ``outputs.hidden_states``
            or an iterable of such indices whose pooled outputs are concatenated.
            (Presumably index 0 is the embedding output and 6 the last DistilBERT
            layer -- confirm against the transformers documentation.)
        dropout: Dropout probability applied to every ``nn.Dropout`` inside the
            encoder and to the head.
        class_weights: Optional per-class weights (tensor or sequence). In
            regression mode they are indexed by ``labels.long()`` to weight the
            squared error; in binary mode ``class_weights[1] / class_weights[0]``
            is used as ``pos_weight``. When omitted the loss is unweighted.
        binary_classifier: ``True`` to train with binary cross-entropy on the
            logits, ``False`` to train with (weighted) MSE. When left as ``None``
            the module-level global ``binary_classifier`` is consulted, which is
            how this flag was originally supplied.
    """

    def __init__(self, transformer_out=6, dropout=0.1, class_weights=None,
                 binary_classifier=None):
        super().__init__()
        # Instead of just using the output of the final hidden layer,
        # you can also pass in a range of hidden layers to concatenate their outputs
        self.transformer_out = (
            range(transformer_out, transformer_out + 1)
            if isinstance(transformer_out, int)
            else transformer_out
        )
        out_dim = len(self.transformer_out) * 768

        # Use pretrained DistilBert. Force it to use our dropout
        self.distilbert = DistilBertModel.from_pretrained(
            "distilbert-base-uncased", output_hidden_states=True
        )  # type: DistilBertModel
        for module in self.distilbert.modules():
            if isinstance(module, torch.nn.Dropout):
                module.p = dropout

        # Then apply a dense hidden layer down to 768, and a final layer down to 1
        self.feedforward = nn.Sequential(
            nn.Linear(out_dim, 768),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(768, 1),
        )

        self.binary_classifier = binary_classifier

        # Always define the weighting attributes so that a model built without
        # class weights can still compute an (unweighted) loss. They are kept as
        # plain attributes, not buffers, so the state_dict keys are unchanged.
        self.class_weights = None
        self.pos_weight = None
        if class_weights is not None:
            self.class_weights = torch.as_tensor(class_weights)
            self.pos_weight = self.class_weights[1] / self.class_weights[0]

    def _use_binary_loss(self):
        """Return True for the BCE loss, False for the MSE loss.

        The instance attribute wins; otherwise fall back to the module-level
        ``binary_classifier`` global for backward compatibility.

        Raises:
            RuntimeError: if neither the attribute nor the global is set.
        """
        flag = getattr(self, "binary_classifier", None)
        if flag is None:
            flag = globals().get("binary_classifier")
        if flag is None:
            raise RuntimeError(
                "Loss type is not configured: pass binary_classifier=True/False to "
                "CustomBert(...) or define a module-level `binary_classifier` flag."
            )
        return bool(flag)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and head.

        Args:
            input_ids: Token ids passed straight to DistilBERT.
            attention_mask: Mask passed to DistilBERT and reused for pooling;
                positions with 0 are excluded from the mean.
            labels: Optional targets. When given, the loss is returned as well.

        Returns:
            ``y`` (one value per example) when ``labels`` is None, otherwise the
            tuple ``(loss, y)``.
        """
        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)

        # Recommended pooling approach for DistilBert is to average over the hidden state sequence
        # instead of outputs.last_hidden_state[:, 0], which is used for Bert which uses [CLS] token
        mask = attention_mask.unsqueeze(-1)
        # Clamp so a row with no attended tokens divides by 1 (giving zeros) rather than 0 (NaN).
        token_counts = mask.sum(dim=1).clamp(min=1)
        pooled_output = [
            (outputs.hidden_states[i] * mask).sum(dim=1) / token_counts
            for i in self.transformer_out
        ]

        # We also concatenate the outputs of multiple layers if chosen by the user
        cat_output = torch.cat(pooled_output, dim=1)

        # Apply dense feedforward
        y = self.feedforward(cat_output).squeeze(-1)

        # Outside the Trainer, we return the predictions
        if labels is None:
            return y

        # Inside the Trainer, we also need to return the loss.
        # Everything is aligned to y's device/dtype instead of a notebook-level constant.
        targets = labels.to(device=y.device, dtype=y.dtype)
        if self._use_binary_loss():
            pos_weight = self.pos_weight
            if pos_weight is not None:
                pos_weight = torch.as_tensor(pos_weight, dtype=y.dtype, device=y.device)
            loss = F.binary_cross_entropy_with_logits(y, targets, pos_weight=pos_weight)
        else:
            loss = mse_loss(y, targets, reduction="none")
            if self.class_weights is not None:
                weights = torch.as_tensor(self.class_weights, dtype=y.dtype, device=y.device)
                # Each example is weighted by the weight of its (integer) label.
                loss = loss * weights[labels.long().to(y.device)]
            loss = loss.mean()
        return loss, y

    def freeze(self):
        """Stop gradient updates for every encoder parameter (the head is untouched)."""
        for param in self.distilbert.parameters():
            param.requires_grad = False

    def unfreeze(self, layer=None):
        """Re-enable gradients for the whole encoder, or for one transformer layer.

        Args:
            layer: ``None`` to unfreeze every encoder parameter, otherwise the
                index of the ``transformer.layer.<layer>`` block to unfreeze.
        """
        # The trailing dot stops layer 1 from also matching layers 10, 11, ...
        prefix = None if layer is None else f"transformer.layer.{layer}."
        for name, param in self.distilbert.named_parameters():
            if prefix is None or name.startswith(prefix):
                param.requires_grad = True

In [43]:
# Build the architecture first; the trained weights are loaded on top of it below.
model = CustomBert() # instantiate model

# load in trained parameters
checkpoint_fp = 'results/model.pth'
# map_location="cpu" lets a checkpoint that was saved from a GPU run load on a
# machine without CUDA; the model itself is still on the CPU at this point.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_fp, map_location="cpu")
model.load_state_dict(checkpoint)
# Switch to inference mode (disables dropout); the returned module is the cell's output.
model.eval()

CustomBert(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear

In [47]:
# Smoke-test the restored model on the first few dev examples.
batch_size = 5
# Stacking the per-row token lists assumes they all have the same length
# (i.e. the split was padded when it was tokenised) -- TODO confirm.
atten_masks = torch.tensor(dev_df["attention_mask"].iloc[:batch_size].tolist(), dtype=torch.long)
input_ids = torch.tensor(dev_df["input_ids"].iloc[:batch_size].tolist(), dtype=torch.long)
# Inference only: skip building the autograd graph for this forward pass.
with torch.no_grad():
    preds = model(input_ids, atten_masks)
preds

tensor([1.0970, 1.8491, 2.2538, 2.6964, 2.0818], grad_fn=<SqueezeBackward1>)