## Stochastic Answer Networks (SAN)

In [1]:
import re
from functools import partial
from itertools import tee
from collections import namedtuple
import os
from datetime import datetime

import pandas as pd
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm import tqdm_notebook

from transformers import BertTokenizer, BertModel

device = torch.device('cuda')

### Data Loading

In [2]:
class SANDataLoading(Dataset):
    def __init__(self, ids, passages, options, answers, pretrained_model='bert-base-cased'):
        
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model, 
                                                       do_basic_tokenize=True,
                                                       never_split=['<BLANK>'])
        
        self.MAX_SEQ_LEN = 512
        
        self.ids = ids
        self.passages = passages
        self.options = options
        self.answers = answers
        
    @classmethod    
    def from_tsv(cls, data_path, pretrained_model='bert-base-cased'):
        dataframe = pd.read_csv(data_path, sep='\t', index_col=0)
        
        ids = dataframe.index.tolist()
        passages = dataframe['본문'].tolist()
        options = list(zip(dataframe['보기1'],dataframe['보기2'],dataframe['보기3'],
                           dataframe['보기4'],dataframe['보기5']))
        answers = dataframe['정답'].tolist()
        return cls(ids, passages, options, answers, pretrained_model)
    
    def __len__(self):
        return len(self._ids)
    
    def __getitem__(self,idx):
        
        id_ = self._ids[idx]
        passage_input_id = self._passage_input_ids[idx]
        passage_token_type_id = self._passage_token_type_ids[idx]
        passage_attention_mask = self._passage_attention_masks[idx] 
        option_input_id = self._option_input_ids[idx]
        option_token_type_id = self._option_token_type_ids[idx]
        option_attention_mask = self._option_attention_masks[idx]
        answer = self._answers[idx]
        
        return id_, passage_input_id, passage_token_type_id, passage_attention_mask, \
    option_input_id, option_token_type_id, option_attention_mask, answer
    
    def _get_tokens(self, text, is_passage):
        
        first_end_idx = text.find('<BLANK>')
        last_start_idx = first_end_idx + len('<BLANK>')
        if is_passage:
            first = self.tokenizer.tokenize(text[:first_end_idx])
            last = self.tokenizer.tokenize(text[last_start_idx:])
            tokens = first + ['<BLANK>'] + last
        else:
            tokens = self.tokenizer.tokenize(text)
        return tokens
    
    def _add_special_token(self, tokens, is_passage):
        
        temp = None
        if is_passage:
            blank_idx = next(i for i, token in enumerate(tokens) if token =='<BLANK>')
        
            first = ['[CLS]'] + tokens[:blank_idx] + ['[SEP]']
            last = tokens[blank_idx+1:] + ['[SEP]']
            temp = first+last
            
        else:
            temp = ['[CLS]'] + tokens + ['[SEP]']
            
        pad = ['[PAD]'] * (self.MAX_SEQ_LEN - len(temp))
            
        return temp + pad
    
    def _get_input_ids(self, tokens):
        return self.tokenizer.convert_tokens_to_ids(tokens)
    
    def _get_token_type_ids(self, tokens):
        
        sep_idx = list(i for i, token in enumerate(tokens) if token=='[SEP]')
        
        token_type_ids = None
        
        if len(sep_idx) > 1: 
            token_type_ids = [0]*(sep_idx[0] + 1) + \
            [1]*(sep_idx[1] - sep_idx[0]) + \
            [0]*(self.MAX_SEQ_LEN - sep_idx[1] -1)
            
        else:
            token_type_ids = [0]*(sep_idx[0] + 1) + \
            [0]*(self.MAX_SEQ_LEN - sep_idx[0] -1)
            
        return token_type_ids
    
    def _get_attention_mask(self, tokens):
        
        pad_idx = next(i for i, token in enumerate(tokens) if token == '[PAD]')
        attention_mask = [1] * pad_idx + [0] * (self.MAX_SEQ_LEN - pad_idx)
        
        return attention_mask
    
    def _preprocess(self):
        
        self._ids = list()
        self._passage_input_ids = list()
        self._passage_token_type_ids = list()
        self._passage_attention_masks = list()
        
        self._option_input_ids = list()
        self._option_token_type_ids = list()
        self._option_attention_masks = list()
       
        self._answers = list()
        
        for id_, passage, options, answer in zip(self.ids, self.passages, self.options, self.answers):
            
            # id
            self._ids += [str(id_) + '_' + str(i+1) for i in range(len(options))]
            
            # passage & options
            passage_tokens = self._get_tokens(passage, True)
            options_tokens = list(self._get_tokens(option, False) for option in options)
            
            passage_tokens = self._add_special_token(passage_tokens, True)
            options_tokens = list(self._add_special_token(option_tokens, False) for option_tokens in options_tokens)
            
            passage_input_id = self._get_input_ids(passage_tokens)
            options_input_id = list(self._get_input_ids(option_tokens) for option_tokens in options_tokens)
            
            self._passage_input_ids += [passage_input_id] * len(options)
            self._option_input_ids += options_input_id
            
            passage_token_type_id = self._get_token_type_ids(passage_tokens)
            options_token_type_id = list(self._get_token_type_ids(option_tokens) for option_tokens in options_tokens)
            
            self._passage_token_type_ids += [passage_token_type_id] * len(options)
            self._option_token_type_ids += options_token_type_id
            
            passage_attention_mask = self._get_attention_mask(passage_tokens)
            options_attention_mask = list(self._get_attention_mask(option_tokens) for option_tokens in options_tokens)
            
            self._passage_attention_masks += [passage_attention_mask] * len(options)
            self._option_attention_masks += options_attention_mask
            
            # answer
            tmp = [1] * len(options)
            tmp[int(answer-1)] = 0
            self._answers += tmp
            
            
    def _collate(self, batch, device):
        
        ids = list()
        passage_input_ids = list()
        passage_token_type_ids = list()
        passage_attention_masks = list()
        option_input_ids = list()
        option_token_type_ids = list()
        option_attention_masks = list()
        answers = list()
        
        for id_, passage_input_id, passage_token_type_id, passage_attention_mask, \
        option_input_id, option_token_type_id, option_attention_mask, answer in batch:
            ids.append(id_)
            
            passage_input_ids.append(torch.LongTensor(passage_input_id).to(device))
            passage_token_type_ids.append(torch.LongTensor(passage_token_type_id).to(device))
            passage_attention_masks.append(torch.LongTensor(passage_attention_mask).to(device))
            
            option_input_ids.append(torch.LongTensor(option_input_id).to(device))
            option_token_type_ids.append(torch.LongTensor(option_token_type_id).to(device))
            option_attention_masks.append(torch.LongTensor(option_attention_mask).to(device))
            answers.append(answer)
                
        passage_input_ids = torch.stack(passage_input_ids, dim=0)
        passage_token_type_ids = torch.stack(passage_token_type_ids, dim=0)
        passage_attention_masks = torch.stack(passage_attention_masks, dim=0)
        
        option_input_ids = torch.stack(option_input_ids, dim=0)
        option_token_type_ids = torch.stack(option_token_type_ids, dim=0)
        option_attention_masks = torch.stack(option_attention_masks, dim=0)
        answers = torch.LongTensor(answers).to(device)
        
        return id_, passage_input_ids, passage_token_type_ids, passage_attention_masks, \
    option_input_ids, option_token_type_ids, option_attention_masks, answers
    
    def __call__(self, batch_size, do_split, train_ratio, num_workers, device):
        
        self._preprocess()
        
        result = None
        if do_split:
            temp_train, temp_test = tee(torch.bernoulli(train_ratio * torch.ones(len(self.ids))), 2)
            temp_train = list(i for i, x in enumerate(temp_train) if x.item() == 1)
            temp_test = list(i for i, x in enumerate(temp_test) if x.item() == 0)
            
            indices_train = list()
            for t in temp_train:
                indices_train += list(range(5*t, 5*(t+1)))
                
            indices_test = list()
            for t in temp_test:
                indices_test += list(range(5*t, 5*(t+1)))
            
            subset_train = Subset(self, indices_train)
            subset_test = Subset(self, indices_test)
            
            train_loader = DataLoader(subset_train, batch_size=batch_size, num_workers=num_workers, 
                          collate_fn=partial(self._collate, device=device))
            test_loader = DataLoader(subset_test, batch_size=batch_size, num_workers=num_workers, 
                          collate_fn=partial(self._collate, device=device))
            result = (train_loader, test_loader)
            
        else:
            data_loader = DataLoader(self, batch_size=batch_size, num_workers=num_workers, 
                          collate_fn=partial(self._collate, device=device))
            result = data_loader
            
        return result 

### Model

In [3]:
class BertWithSan(nn.Module):
    def __init__(self, pretrained_model='bert-base-cased', hidden_size=768, class_size=2, dropout_prob=0.4, K=7):
        super().__init__()
        
        self.K = K
        
        self.bert_p = BertModel.from_pretrained(pretrained_model) # BERT: Passage encoder
        self.bert_h = BertModel.from_pretrained(pretrained_model) # BERT: Option encoder
        
        self.w_1 = nn.Parameter(torch.rand((hidden_size,))) # M^h to alpha
        self.W_2 = nn.Parameter(torch.rand((hidden_size, 512))) # M^p to beta 
        self.W_3 = nn.Parameter(torch.rand((1, 4))) # hidden states to P_r
        self.grucell = nn.GRUCell(hidden_size, hidden_size)
        
        self.dropout = nn.Dropout(dropout_prob) 
        self.classifier = nn.Linear(hidden_size, 2)
        
        
    def forward(self, passage_input_ids, passage_token_type_ids, passage_attention_masks, 
                option_input_ids, option_token_type_ids, option_attention_masks):
        
        # passage input ids shape: (batch_size, max_position_embeddings)
        # passage token type ids shape: (batch_size, max_position_embeddings)
        # option input ids shape: (batch_size, max_position_embeddings)
        # option token type ids shape: (batch_size, max_position_embeddings)
        
        M_p, _ = self.bert_p(passage_input_ids, passage_token_type_ids, passage_attention_masks, False)
        M_h, _ = self.bert_h(option_input_ids, option_token_type_ids, option_attention_masks, False)
        # M_p shape: (batch_size, max_position_embeddings, hidden_size)
        # M_h shape: (batch_size, max_position_embeddings, hidden_size)
        
        alpha = F.softmax(self.w_1*M_h).squeeze(2)
        # w_1 shape: (1, hidden_size)
        # alpha shape: (batch_size, max_position_embeddings)
        
        s_0 = torch.sum(alpha*M_h, dim=1)
        # s_0 shape: (batch_size, hidden_size)
        
        temp = list()
        for k in range(self.K):
            
            s_k_1 = None
            # s^{k-1}
            if k == 0:
                s_k_1 = s_0
            else:
                s_k_1 = s_k
            
            beta = F.softmax((torch.matmul(s_k_1, self.W_2).unsqueeze(1)@M_p).squeeze(1))
            # beta shape: (batch_size, hidden_size)
            
            x_k = torch.sum((beta.unsqueeze(1)*M_p), dim=1)
            # x_k shape: (batch_size, hidden_size)
            
            s_k = self.grucell(x_k, s_k_1)
            # s_k shape: (batch_size, hidden_size)
            
            P_r_k = F.softmax(torch.matmul(self.W_3,
                                           torch.stack([s_k, x_k, torch.abs(s_k-x_k), (s_k*x_k)], dim=1)))
            # P_r_k shape: (batch_size, 1, hidden_size)
            
            temp.append(P_r_k)
        
        P_r = torch.cat(temp, dim=1)
        P_r = self.dropout(P_r)
        P_r = torch.mean(P_r, dim=1)
        # P_r shape: (batch_size, hidden_size)
        
        logits = self.classifier(P_r)
        # logits shape: (batch_size, 2)
        
        return logits

### train / evaluate

In [4]:
def train(model, data_parallel, data_loader, optimizer, criterion):
    
    model.train()
    if data_parallel: # use Data Parallelism with Multi-GPU
        model = nn.DataParallel(model)
    epoch_loss = 0
    iter_bar = tqdm_notebook(data_loader, desc='Iter (loss=X.XXX)')
    
    for i, batch in enumerate(iter_bar):
        optimizer.zero_grad()
        
        passage_input_ids, passage_token_type_ids, passage_attention_masks, \
                option_input_ids, option_token_type_ids, option_attention_masks = batch[1:-1]
        label = batch[-1]

        logits = model(passage_input_ids, passage_token_type_ids, passage_attention_masks,
                           option_input_ids, option_token_type_ids, option_attention_masks)
        
        loss = criterion(logits, label)
        loss.backward()
        
        optimizer.step()
        epoch_loss += loss.item()
        iter_bar.set_description('Iter (loss={})'.format(loss.item()))
        
    return epoch_loss / len(data_loader)


def evaluate(model, data_loader, criterion):
    
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            
            passage_input_ids, passage_token_type_ids, passage_attention_masks, \
                option_input_ids, option_token_type_ids, option_attention_masks = batch[1:-1]
            label = batch[-1]
            
            logits = model(passage_input_ids, passage_token_type_ids, passage_attention_masks,
                           option_input_ids, option_token_type_ids, option_attention_masks)
            
            loss = criterion(logits, label)
            epoch_loss += loss.item()
            
    return epoch_loss / len(data_loader)


def score(model, load_path, data_loader, device):
    
    # load model
    if load_path is not None:
        model.load_state_dict(torch.load(load_path))
        print('Loading the model from', load_path)
    
    model.eval()
    match = 0
    N = 0
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            
            passage_input_ids, passage_token_type_ids, passage_attention_masks, \
                option_input_ids, option_token_type_ids, option_attention_masks = batch[1:-1]
            label = batch[-1]
            
            true = torch.argmin(label)
            
            logits = model(passage_input_ids, passage_token_type_ids, passage_attention_masks,
                           option_input_ids, option_token_type_ids, option_attention_masks)
            
            pred = torch.argmax(logits.transpose(0,1)[0])
            
            if true.item() == pred.item():
                match += 1
                
            N += 1
            
    score = match / N
    
    return score


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, save_path, patience=1, verbose=False):
        """
        Args:
            patience (int): How long to wait agter last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improved.
                            Default: False
        """
        self.save_path = save_path
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        
    def __call__(self, val_loss, model):
        
        score = val_loss
        
        if self.best_score is None:
            torch.save(model.state_dict(), self.save_path)
            print("Saving the model to", self.save_path)
            self.best_score = score
            
        elif score > self.best_score:
            self.counter += 1 
            print('EarlyStopping counter: {} out of {}'.format(self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True
                
        else:
            torch.save(model.state_dict(), self.save_path)
            print("Saving the model to", self.save_path)
            self.best_score = score
            self.counter = 0
            
        return self.early_stop

### main

In [12]:
data_loading_train = SANDataLoading.from_tsv('KSAT_TRAIN.tsv')
train_loader = data_loading_train(batch_size=5, do_split=False, train_ratio=0.7, num_workers=0, device=device)

The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.


In [13]:
data_loading_test = SANDataLoading.from_tsv('KSAT_TEST.tsv')
test_loader = data_loading_test(batch_size=5, do_split=False, train_ratio=0.7, num_workers=0, device=device)

The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.


In [6]:
model = BertWithSan().to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss(torch.FloatTensor([5,1.25]).to(device))

In [8]:
now = datetime.now()
year = str(now.year)[-2:]
month = '{00}'.format(now.month)
day = '{00}'.format(now.day)
hour = '{00}'.format(now.hour)
min_ = '{00}'.format(now.minute)

model_name = model.__class__.__name__ +'_'+year+month+day+hour+min_
# PATH = os.path.join('./',model_name+'_'epoch+'.pt')

In [9]:
N_EPOCHS = 50

# SAVE_PATH = PATH
# print('SAVE PATH: {}'.format(SAVE_PATH))

patience = 50
verbose = True

# early_stopping = EarlyStopping(SAVE_PATH, patience, verbose)

for epoch in range(N_EPOCHS):
    
    train_loss = train(model, False, train_loader, optimizer, criterion)
#     test_loss = evaluate(model, test_loader, criterion)
    
#     early_stop = early_stopping(test_loss, model)
#     if early_stop:
#         print('Epoch {}| Train Loss : {}, Test Loss : {}'.format(epoch+1, train_loss, test_loss))
#         break
    if epoch % 5 == 0:
        torch.save(model.state_dict(), os.path.join('./',model_name+'_'+str(epoch)+'.pt'))
        print("Saving the model to", os.path.join('./',model_name+'_'+str(epoch)+'.pt'))
    print('Epoch {}| Train Loss : {}'.format(epoch+1, train_loss))
    
torch.save(model.state_dict(), os.path.join('./',model_name+'.pt'))
print("Saving the model to", os.path.join('./',model_name+'.pt'))

HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…




Saving the model to ./BertWithSan_196271656_0.pt
Epoch 1| Train Loss : 0.6746212026515565


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 2| Train Loss : 0.5951903940795304


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 3| Train Loss : 0.4711210793682507


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 4| Train Loss : 0.5138380392998844


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 5| Train Loss : 0.42781256856275846


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Saving the model to ./BertWithSan_196271656_5.pt
Epoch 6| Train Loss : 0.6075884458887113


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 7| Train Loss : 0.46222117078768743


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 8| Train Loss : 0.637736383390117


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 9| Train Loss : 0.7025562526343705


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 10| Train Loss : 0.6971912843066377


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Saving the model to ./BertWithSan_196271656_10.pt
Epoch 11| Train Loss : 0.6961713261805571


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 12| Train Loss : 0.6819497602713572


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 13| Train Loss : 0.6951787794178182


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 14| Train Loss : 0.6978334571634021


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 15| Train Loss : 0.6983771541675964


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Saving the model to ./BertWithSan_196271656_15.pt
Epoch 16| Train Loss : 0.7015916913360745


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 17| Train Loss : 0.6953795683074308


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 18| Train Loss : 0.696494239646119


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 19| Train Loss : 0.6964538423271922


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 20| Train Loss : 0.6949351021221706


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Saving the model to ./BertWithSan_196271656_20.pt
Epoch 21| Train Loss : 0.6939907513655625


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 22| Train Loss : 0.6936609835593731


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 23| Train Loss : 0.694457460301263


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 24| Train Loss : 0.6943544382398779


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 25| Train Loss : 0.6929478311693513


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Saving the model to ./BertWithSan_196271656_25.pt
Epoch 26| Train Loss : 0.6937886233453626


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 27| Train Loss : 0.6945131372321736


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 28| Train Loss : 0.69403442252766


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 29| Train Loss : 0.6954437900673259


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 30| Train Loss : 0.6949483992217423


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Saving the model to ./BertWithSan_196271656_30.pt
Epoch 31| Train Loss : 0.6946029259786978


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 32| Train Loss : 0.694708419775034


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 33| Train Loss : 0.6911649958653884


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 34| Train Loss : 0.6989743505979513


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 35| Train Loss : 0.6915545255332798


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Saving the model to ./BertWithSan_196271656_35.pt
Epoch 36| Train Loss : 0.6964405599352601


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 37| Train Loss : 0.6927487308328802


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 38| Train Loss : 0.693624522159626


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 39| Train Loss : 0.694505946744572


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 40| Train Loss : 0.6968522061001171


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Saving the model to ./BertWithSan_196271656_40.pt
Epoch 41| Train Loss : 0.6941272448409688


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 42| Train Loss : 0.6954089463531197


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 43| Train Loss : 0.6948294290474483


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 44| Train Loss : 0.6932845372658272


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 45| Train Loss : 0.6937738368263492


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Saving the model to ./BertWithSan_196271656_45.pt
Epoch 46| Train Loss : 0.6933478099959237


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 47| Train Loss : 0.6947168497296122


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 48| Train Loss : 0.6939676208929582


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 49| Train Loss : 0.6934667763771949


HBox(children=(IntProgress(value=0, description='Iter (loss=X.XXX)', max=770, style=ProgressStyle(description_…


Epoch 50| Train Loss : 0.6956129141442188
Saving the model to ./BertWithSan_196271656.pt


### predict

In [10]:
# LOAD_PATH = PATH

train_score = score(model, None, train_loader, device)
test_score = score(model, None, test_loader, device)
print('Train | {}\nTest | {}'.format(train_score, test_score))



Train | 0.19220779220779222
Test | 0.06315789473684211


#### Final models

In [16]:
for file in sorted(filter(lambda x: x.startswith('BertWithSan_196271656'), os.listdir('./'))):
    path = os.path.join('./', file)
    print('FileName: {}'.format(file))
    train_score = score(model, path, train_loader, device)
    test_score = score(model, path, test_loader, device)
    print('Train | {}\nTest | {}'.format(train_score, test_score))

FileName: BertWithSan_196271656.pt
Loading the model from ./BertWithSan_196271656.pt




Loading the model from ./BertWithSan_196271656.pt
Train | 0.19220779220779222
Test | 0.06315789473684211
FileName: BertWithSan_196271656_0.pt
Loading the model from ./BertWithSan_196271656_0.pt
Loading the model from ./BertWithSan_196271656_0.pt
Train | 0.34675324675324676
Test | 0.35789473684210527
FileName: BertWithSan_196271656_10.pt
Loading the model from ./BertWithSan_196271656_10.pt
Loading the model from ./BertWithSan_196271656_10.pt
Train | 0.6948051948051948
Test | 0.29473684210526313
FileName: BertWithSan_196271656_15.pt
Loading the model from ./BertWithSan_196271656_15.pt
Loading the model from ./BertWithSan_196271656_15.pt
Train | 0.6441558441558441
Test | 0.3684210526315789
FileName: BertWithSan_196271656_20.pt
Loading the model from ./BertWithSan_196271656_20.pt
Loading the model from ./BertWithSan_196271656_20.pt
Train | 0.44935064935064933
Test | 0.17894736842105263
FileName: BertWithSan_196271656_25.pt
Loading the model from ./BertWithSan_196271656_25.pt
Loading the mo