This notebook shows how to use a pretrained language model and how to extract text representations from it. Google Colab and Kaggle both offer free GPU resources.

Kaggle: https://www.kaggle.com

In [1]:
# you can use this command to install libraries you don't have
#!pip install transformers 



In [2]:
import pandas as pd
from transformers import DistilBertTokenizer, DistilBertModel
import torch
import torch.nn.functional as F

## Check if GPU is available 

In [3]:
torch.cuda.is_available()

False

In [4]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cpu')

## Read data from csv

In [5]:
train_df = pd.read_csv('C:/Users/Andreas/Downloads/dataset/training_sent.csv')

In [6]:
example = train_df.citation_context.values[0]
example

'Graph cuts optimization was applied to the output of a hierarchical two-stage classifier, which was trained to identify the cartilage and bone voxels MAINCIT . As the underlying multi-label graph cuts jointly consider independent classifier outputs for cartilage, bone, and background, label-conflict-resolution may be challenging in regions with multiple labels.'

## Load tokenizer and pretrained language model
You can find more pretrained language models on: https://huggingface.co/models

In [7]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

In [8]:
model = DistilBertModel.from_pretrained('distilbert-base-uncased').to(DEVICE)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Using the tokenizer to get the token IDs

In [9]:
example_input = tokenizer(example, return_tensors = 'pt', max_length = 512, truncation = True, padding = 'max_length').to(DEVICE)
example_input

{'input_ids': tensor([[  101, 10629,  7659, 20600,  2001,  4162,  2000,  1996,  6434,  1997,
          1037, 25835,  2048,  1011,  2754,  2465, 18095,  1010,  2029,  2001,
          4738,  2000,  6709,  1996, 11122, 11733,  3351,  1998,  5923, 29450,
          9050,  2364, 26243,  1012,  2004,  1996, 10318,  4800,  1011,  3830,
         10629,  7659, 10776,  5136,  2981,  2465, 18095, 27852,  2005, 11122,
         11733,  3351,  1010,  5923,  1010,  1998,  4281,  1010,  3830,  1011,
          4736,  1011,  5813,  2089,  2022, 10368,  1999,  4655,  2007,  3674,
         10873,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,  
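
The call in cell 9 truncates the text to at most 512 tokens and pads shorter inputs up to that length, which is why the output ends in a long run of zeros. The sketch below imitates that behaviour on a toy list of IDs. `pad_or_truncate` is a hypothetical helper written for illustration, not part of the `transformers` API; the IDs 101, 102 and 0 are the `[CLS]`, `[SEP]` and padding IDs visible in the output above.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special-token IDs seen in the output above


def pad_or_truncate(token_ids, max_length):
    """Hypothetical helper imitating truncation=True with padding='max_length'."""
    body = token_ids[: max_length - 2]        # leave room for [CLS] and [SEP]
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)     # 1 marks a real token
    n_pad = max_length - len(input_ids)
    return input_ids + [PAD_ID] * n_pad, attention_mask + [0] * n_pad


ids, mask = pad_or_truncate([10629, 7659, 20600], max_length=8)
print(ids)   # [101, 10629, 7659, 20600, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

The attention mask is what lets later steps tell real tokens apart from padding.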

## Obtaining the text representation from DistilBERT

In [10]:
model.train(False)  # inference mode (disables dropout); equivalent to model.eval()
example_representation = model(**example_input).last_hidden_state
example_representation

tensor([[[-0.5604, -0.2248, -0.3066,  ..., -0.2467, -0.0082,  0.4634],
         [-0.1042, -0.3635, -0.2408,  ...,  0.3454,  0.4058,  0.6786],
         [ 0.0271, -0.1428, -0.1220,  ...,  0.0729, -0.0081, -0.0783],
         ...,
         [-0.2101, -0.1302,  0.1094,  ..., -0.2197, -0.1336,  0.1365],
         [ 0.0850,  0.1100,  0.0469,  ..., -0.4255, -0.2457,  0.1477],
         [-0.3461,  0.0067,  0.0949,  ..., -0.2213, -0.0230, -0.0052]]],
       grad_fn=<NativeLayerNormBackward0>)
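
The `grad_fn=...` in the output above indicates that autograd tracked this forward pass: `model.train(False)` only switches layers such as dropout to inference behaviour and does not turn off gradient tracking. The sketch below illustrates the difference with a small toy network standing in for DistilBERT (the `toy` module is an assumption made for illustration, not part of the notebook).

```python
import torch

# Toy stand-in for the language model: a linear layer followed by dropout
toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
toy.eval()  # same effect as toy.train(False): dropout becomes a no-op

x = torch.ones(1, 4)

with_grad = toy(x)            # autograd still records this forward pass
with torch.no_grad():
    without_grad = toy(x)     # no graph is built, which saves memory

print(with_grad.requires_grad)     # True
print(without_grad.requires_grad)  # False
```

When the representation is only used for inference, wrapping the forward pass in `torch.no_grad()` gives the same values without storing the computation graph.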

In [11]:
example_representation.size()

torch.Size([1, 512, 768])

In [12]:
example_representation_mean = example_representation.mean(dim = 1)
example_representation_mean.size()

torch.Size([1, 768])
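
Because the input was padded to 512 tokens, the plain mean over `dim = 1` also averages the vectors at the padding positions. One common alternative is to weight the mean by the attention mask so that only real tokens contribute. The sketch below uses toy tensors; `masked_mean` is a hypothetical helper, and in this notebook its inputs would presumably be `example_representation` and `example_input['attention_mask']`.

```python
import torch


def masked_mean(hidden, attention_mask):
    """Average token vectors over real tokens only (mask == 1)."""
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (batch, seq, 1)
    summed = (hidden * mask).sum(dim=1)                   # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1.0)               # (batch, 1), avoids division by zero
    return summed / counts


# Toy stand-ins: one sequence with two real tokens and one padding position
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
attention_mask = torch.tensor([[1, 1, 0]])

pooled = masked_mean(hidden, attention_mask)
print(pooled)  # tensor([[2., 3.]])
```

The padding vector `[100., 100.]` has no influence on the result, whereas `hidden.mean(dim=1)` would include it.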

## Reading a paper from a txt file

In [13]:
with open('C:/Users/Andreas/Downloads/dataset/papers/1307.2965.txt', encoding='utf-8') as file:
    # one sentence per line; strip surrounding whitespace and newlines
    paper = [line.strip() for line in file]
paper

['[CLS]  Semantic Context Forests for Learning-Based Knee Cartilage Segmentation in 3D MR Images Learning-Based Knee Cartilage Segmentation Quan WangFORMULA   [SEP]',
 '[CLS] Dijia WuFORMULA   [SEP]',
 '[CLS] Le LuFORMULA   [SEP]',
 '[CLS] Meizhu Liu FORMULA Kim L. BoyerFORMULA   [SEP]',
 '[CLS] Shaohua Kevin ZhouFORMULA  Q. Wang D. Wu [SEP]',
 '[CLS] L. Lu M. Liu K. L. Boyer S. K. Zhou FORMULA Siemens Corporate Research, Princeton, NJ 08540, USA FORMULA Rensselaer Polytechnic Institute, Troy, NY 12180, USA Learning-Based Knee Cartilage Segmentation Quan Wang, et al. The automatic segmentation of human knee cartilage from 3D MR images is a useful yet challenging task due to the thin sheet structure of the cartilage with diffuse boundaries and inhomogeneous intensities. [SEP]',
 '[CLS] In this paper, we present an iterative multi-class learning method to segment the femoral, tibial and patellar cartilage simultaneously, which effectively exploits the spatial contextual constraints betwe

## Obtaining the text representation of the whole paper

In [14]:
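# Accumulate the token-level hidden states over all sentences of the paper
model.train(False)  # inference mode; setting it once before the loop is enough
sent_sum = torch.zeros(1, 512, 768).to(DEVICE)
for sent in paper:
    sent_input = tokenizer(sent, return_tensors = 'pt', max_length = 512, truncation = True, padding = 'max_length').to(DEVICE)
    with torch.no_grad():  # no gradients are needed for feature extraction
        sent_sum += model(**sent_input).last_hidden_state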

In [15]:
paper_mean = sent_sum.mean(dim=1)
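
Note that `sent_sum` is a sum over sentences, so `paper_mean` is the sum of the per-sentence mean vectors rather than their average. The two differ only by the factor `len(paper)`. The sketch below checks this relationship on small random tensors that stand in for the model outputs (the shapes and the number of sentences are arbitrary assumptions).

```python
import torch

torch.manual_seed(0)
# Stand-ins for three sentences' last_hidden_state: (1, seq_len, hidden)
states = [torch.randn(1, 4, 3) for _ in range(3)]

# What the notebook does: sum over sentences, then mean over tokens
sent_sum = torch.zeros(1, 4, 3)
for s in states:
    sent_sum += s
paper_sum_vec = sent_sum.mean(dim=1)                               # (1, hidden)

# Alternative: mean over tokens per sentence, then average over sentences
per_sentence = torch.cat([s.mean(dim=1) for s in states], dim=0)  # (3, hidden)
paper_avg_vec = per_sentence.mean(dim=0, keepdim=True)            # (1, hidden)

# The two vectors point in the same direction and differ by the sentence count
print(torch.allclose(paper_sum_vec, len(states) * paper_avg_vec, atol=1e-6))
```

Dividing `paper_mean` by `len(paper)` would give a true average, which matters if the vector is used for anything that depends on its magnitude.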

## Calculating cosine similarity between the paper and citation context

In [16]:
cosine_similarity = F.cosine_similarity(example_representation_mean, paper_mean)

In [17]:
cosine_similarity.item()

0.9062259793281555
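
Cosine similarity is the dot product of two vectors divided by the product of their norms, so it depends only on their directions and not on their lengths. The sketch below compares a manual computation with `F.cosine_similarity` on two toy vectors (the values are arbitrary) and checks that rescaling one vector leaves the score unchanged.

```python
import torch
import torch.nn.functional as F

a = torch.tensor([[1.0, 2.0, 3.0]])
b = torch.tensor([[2.0, 0.0, 1.0]])

# Manual formula: dot(a, b) / (||a|| * ||b||), computed row-wise
manual = (a * b).sum(dim=1) / (a.norm(dim=1) * b.norm(dim=1))
builtin = F.cosine_similarity(a, b)        # compares along dim=1 by default
scaled = F.cosine_similarity(a, 7.0 * b)   # rescaling b does not change the score

print(manual.item(), builtin.item(), scaled.item())
```

A score close to 1 means the two vectors point in nearly the same direction.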