In [1]:
# https://github.com/nlpyang/PreSumm → summarizer

In [2]:
import os
import torch
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from torch.nn.utils.rnn import pad_sequence
from transformers import BertForSequenceClassification, AdamW

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [4]:
class CoLADataset(Dataset):
    def __init__(self, path, tokenizer, is_train=True):
        if is_train:
            filename = os.path.join(path, 'raw/in_domain_train.tsv')
        else:
            filename = os.path.join(path, 'raw/in_domain_dev.tsv')
        df = pd.read_csv(filename, \
                               sep='\t', \
                               names=['source', 'label', 'judgement', 'text'])
        self.input_ids = []
        self.token_type_ids = []
        self.attention_mask = []
        for t in df.text:
            inp = tokenizer(t, return_tensors='pt')
            self.input_ids.append(inp['input_ids'])
            self.token_type_ids.append(inp['token_type_ids'])
            self.attention_mask.append(inp['attention_mask'])
        self.label = df.label
        
    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return [self.input_ids[idx], self.token_type_ids[idx], self.attention_mask[idx], self.label[idx]]

In [5]:
train_dataset = CoLADataset('../../data/cola_classification', tokenizer)
eval_dataset = CoLADataset('../../data/cola_classification', tokenizer, is_train=False)

In [6]:
len(train_dataset), len(eval_dataset)

(8551, 527)

In [7]:
def collate_fn(batch):
    input_ids = [b[0][0] for b in batch]
    token_type_ids = [b[1][0] for b in batch]
    attention_mask = [b[2][0] for b in batch]
    label = torch.tensor([b[3] for b in batch])
    input_ids = pad_sequence(input_ids, batch_first=True)
    token_type_ids = pad_sequence(token_type_ids, batch_first=True)
    attention_mask = pad_sequence(attention_mask, batch_first=True)
    return input_ids, token_type_ids, attention_mask, label

In [8]:
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=16, shuffle=True)

In [9]:
for i, d in enumerate(train_dataloader):
    if i > 10:
        break
    input_ids, token_type_ids, attention_mask, labels = d
    print(input_ids.shape, token_type_ids.shape, attention_mask.shape, labels.shape)

torch.Size([16, 26]) torch.Size([16, 26]) torch.Size([16, 26]) torch.Size([16])
torch.Size([16, 17]) torch.Size([16, 17]) torch.Size([16, 17]) torch.Size([16])
torch.Size([16, 15]) torch.Size([16, 15]) torch.Size([16, 15]) torch.Size([16])
torch.Size([16, 17]) torch.Size([16, 17]) torch.Size([16, 17]) torch.Size([16])
torch.Size([16, 15]) torch.Size([16, 15]) torch.Size([16, 15]) torch.Size([16])
torch.Size([16, 20]) torch.Size([16, 20]) torch.Size([16, 20]) torch.Size([16])
torch.Size([16, 20]) torch.Size([16, 20]) torch.Size([16, 20]) torch.Size([16])
torch.Size([16, 12]) torch.Size([16, 12]) torch.Size([16, 12]) torch.Size([16])
torch.Size([16, 23]) torch.Size([16, 23]) torch.Size([16, 23]) torch.Size([16])
torch.Size([16, 32]) torch.Size([16, 32]) torch.Size([16, 32]) torch.Size([16])
torch.Size([16, 28]) torch.Size([16, 28]) torch.Size([16, 28]) torch.Size([16])


In [10]:
# Load BertForSequenceClassification, the pretrained BERT model with a single 
# linear classification layer on top. 
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels = 2)
model.train()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [11]:
optimizer = AdamW(model.parameters(), lr = 2e-5, eps = 1e-8)
loss = nn.CrossEntropyLoss()

In [12]:
def train(model, dataloader, optimizer):
    for i, d in enumerate(dataloader):
        optimizer.zero_grad()
        input_ids, token_type_ids, attention_mask, labels = d
        out = model(input_ids=input_ids, labels=labels)
        loss = out[0]
        
        #print('before backward: {}'.format(loss))
        loss.backward()
        print('after backward: {}'.format(loss))
        
        optimizer.step()
        
        #break

In [13]:
train(model, train_dataloader, optimizer)

after backward: 0.682746946811676
after backward: 0.5892165303230286
after backward: 0.5971190929412842
after backward: 0.5736703276634216
after backward: 0.6826539039611816
after backward: 0.594929575920105
after backward: 0.7404893040657043
after backward: 0.6180349588394165
after backward: 0.6880731582641602
after backward: 0.6873162984848022
after backward: 0.7961403727531433
after backward: 0.6500728726387024
after backward: 0.6167759299278259
after backward: 0.5961862802505493
after backward: 0.5941612720489502
after backward: 0.5366067886352539
after backward: 0.7621889114379883
after backward: 0.49336758255958557
after backward: 0.7208257913589478
after backward: 0.5851196646690369
after backward: 0.576001763343811
after backward: 0.5024921298027039
after backward: 0.4966876208782196
after backward: 0.6248321533203125
after backward: 0.6202273964881897
after backward: 0.5783764123916626
after backward: 0.6319259405136108
after backward: 0.6065897941589355
after backward: 0.5712

after backward: 0.5549787878990173
after backward: 0.4486159086227417
after backward: 0.602379560470581
after backward: 0.36974450945854187
after backward: 0.4278520941734314
after backward: 0.3878175914287567
after backward: 0.48019644618034363
after backward: 0.5778515338897705
after backward: 0.36958763003349304
after backward: 0.6079637408256531
after backward: 0.5354700088500977
after backward: 0.5132678151130676
after backward: 0.4545446038246155
after backward: 0.5322014093399048
after backward: 0.5130562782287598
after backward: 0.6457064747810364
after backward: 0.4782536029815674
after backward: 0.484110027551651
after backward: 0.3198695182800293
after backward: 0.4555729627609253
after backward: 0.4262423515319824
after backward: 0.4781089127063751
after backward: 0.3976157009601593
after backward: 0.36007121205329895
after backward: 0.30758973956108093
after backward: 0.45400115847587585
after backward: 0.34655269980430603
after backward: 0.8499913811683655
after backward:

after backward: 0.25974270701408386
after backward: 0.33952850103378296
after backward: 0.23443208634853363
after backward: 0.5619715452194214
after backward: 0.41095876693725586
after backward: 0.3504435420036316
after backward: 0.19272997975349426
after backward: 0.5429831147193909
after backward: 0.416250079870224
after backward: 0.44957268238067627
after backward: 0.32560616731643677
after backward: 0.28038734197616577
after backward: 0.5670682191848755
after backward: 0.37773725390434265
after backward: 0.433677077293396
after backward: 0.23663640022277832
after backward: 0.3985004127025604
after backward: 0.6404274702072144
after backward: 0.42583486437797546
after backward: 0.4338298439979553
after backward: 0.3775867223739624
after backward: 0.9495734572410583
after backward: 0.7629899382591248
after backward: 0.2331821471452713
after backward: 0.37243711948394775
after backward: 0.2740698754787445
after backward: 0.5311015844345093
after backward: 0.30894342064857483
after bac

In [14]:
train(model, train_dataloader, optimizer)

after backward: 0.09486492723226547
after backward: 0.24622975289821625
after backward: 0.4799994230270386
after backward: 0.27298837900161743
after backward: 0.08829061686992645
after backward: 0.6410375237464905
after backward: 0.08769513666629791
after backward: 0.10987889021635056
after backward: 0.13490232825279236
after backward: 0.42762675881385803
after backward: 0.3240242302417755
after backward: 0.1917620301246643
after backward: 0.08079755306243896
after backward: 0.44968703389167786
after backward: 0.2532300055027008
after backward: 0.5644389390945435
after backward: 0.24666732549667358
after backward: 0.1352958232164383
after backward: 0.12333741784095764
after backward: 0.10586340725421906
after backward: 0.4008282423019409
after backward: 0.40094703435897827
after backward: 0.39249175786972046
after backward: 0.10237585753202438
after backward: 0.1559070497751236
after backward: 0.2538176476955414
after backward: 0.23042088747024536
after backward: 0.4617820978164673
aft

after backward: 0.23587512969970703
after backward: 0.03916380926966667
after backward: 0.2509854733943939
after backward: 0.06710845232009888
after backward: 0.29833483695983887
after backward: 0.47011780738830566
after backward: 0.3735921084880829
after backward: 0.5312666893005371
after backward: 0.06500791758298874
after backward: 0.09102996438741684
after backward: 0.30447256565093994
after backward: 0.0809883326292038
after backward: 0.3997194766998291
after backward: 0.3662208020687103
after backward: 0.37098807096481323
after backward: 0.22690531611442566
after backward: 0.13482500612735748
after backward: 0.39417678117752075
after backward: 0.21313311159610748
after backward: 0.24930445849895477
after backward: 0.12667474150657654
after backward: 0.38752150535583496
after backward: 0.2998619079589844
after backward: 0.12131449580192566
after backward: 0.4092526435852051
after backward: 0.12759043276309967
after backward: 0.16821175813674927
after backward: 0.1995803713798523
a

after backward: 0.057256292551755905
after backward: 0.30755484104156494
after backward: 0.102604940533638
after backward: 0.4680202305316925
after backward: 0.37565845251083374
after backward: 0.3067375719547272
after backward: 0.194517120718956
after backward: 0.40661224722862244
after backward: 0.15022581815719604
after backward: 0.2618543803691864
after backward: 0.06553870439529419
after backward: 0.4948265552520752
after backward: 0.05641011893749237
after backward: 0.2665047347545624
after backward: 0.05750419944524765
after backward: 0.34798988699913025
after backward: 0.13552261888980865
after backward: 0.26064935326576233
after backward: 0.5246537923812866
after backward: 0.6139988303184509
after backward: 0.14062006771564484
after backward: 0.3377191424369812
after backward: 0.4448888897895813
after backward: 0.0773424431681633
after backward: 0.42858681082725525
after backward: 0.49241510033607483
after backward: 0.2543289363384247
after backward: 0.279665470123291
after ba