In [None]:
!pip install transformers
!pip install torchmetrics
!pip install pytorch_lightning

# Section 1: Fine-tune a BERT masked language model per class and sample sentences from it

In [None]:
from google.colab import drive
from torch import nn
from transformers import BertTokenizer, BertModel
import torch
import csv
from torch.utils.data import Dataset
from transformers import BertTokenizer, BertModel, BertForMaskedLM, AdamW
from torchmetrics import Accuracy
import os
from torch.utils.data.dataloader import DataLoader
from torch.nn.utils.rnn import pad_sequence
import tqdm
import time
import math
import numpy as np
from transformers import DataCollatorForLanguageModeling
import pytorch_lightning as pl

In [None]:
# Mount Google Drive so the CSV data and the saved models are reachable.
drive.mount('/content/drive')

# Root folder of the dataset on Drive.
data_path = '/content/drive/MyDrive/data'

# One CSV per class inside the training split.
train_data_path = data_path + '/train'
true_train_path = train_data_path + '/true.csv'
false_train_path = train_data_path + '/false.csv'

In [None]:
# WordPiece tokenizer for bert-base-uncased.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# NOTE(review): this BertModel is not referenced by the visible Section 1 code
# (BertLM / BertLMPred load their own BertForMaskedLM); confirm it is needed.
model = BertModel.from_pretrained("bert-base-uncased")

# Integer class ids attached to the CSV files.
true_label, false_label = 1, 0

# Destinations of the fine-tuned per-class masked-LM state_dicts.
true_saved_model_path = '/content/drive/MyDrive/models/true_bert_lm.pt'
false_saved_model_path = '/content/drive/MyDrive/models/false_bert_lm.pt'

In [None]:
class LMSpellCheckingDataset(Dataset):
    """Dataset of raw texts for masked-language-model fine-tuning.

    The first column of every CSV row is one text. The label from ``labels`` is
    stored next to it, but ``__getitem__`` returns only the tokenized
    ``input_ids``; masking is left to the collate function.

    Args:
        tokenizer: callable (e.g. a HuggingFace tokenizer) returning a mapping
            with an ``'input_ids'`` entry.
        data_paths: list of CSV file paths.
        labels: one label per entry of ``data_paths``.
        batch_size: stored for reference; not used by the dataset itself.

    Raises:
        ValueError: if ``data_paths`` and ``labels`` differ in length.
    """

    def __init__(self, tokenizer, data_paths, labels, batch_size=32):
        if len(data_paths) != len(labels):
            raise ValueError('data_paths and labels must have the same length')
        self.dataset = []
        for data_path, label in zip(data_paths, labels):
            # newline='' is what the csv module expects, so quoted fields that
            # contain newlines are parsed correctly.
            with open(data_path, 'r', encoding='utf-8', newline='') as file:
                for item in csv.reader(file):
                    if not item:
                        # Blank lines come back as [] and would fail on item[0].
                        continue
                    self.dataset.append((item[0], label))
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def __getitem__(self, idx):
        """Return the token ids of item ``idx`` as a 1-D tensor."""
        if self.tokenizer is None:
            raise Exception('Tokenizer cannot be null')

        tweet, _label = self.dataset[idx]
        input_ids = self.tokenizer(tweet)['input_ids']
        return torch.tensor(input_ids)

    def __len__(self):
        return len(self.dataset)

class BertLM(pl.LightningModule):
    """Lightning wrapper that fine-tunes ``bert-base-uncased`` as a masked LM.

    The wrapped model is stored as ``self.bert``, so the saved ``state_dict``
    keys (``bert.*``) match ``BertLMPred``, which reloads them for generation.

    Args:
        class_name: name of the class this instance is trained on; only used in
            log messages.
    """

    def __init__(self, class_name):
        super().__init__()
        self.bert = BertForMaskedLM.from_pretrained('bert-base-uncased')
        self.epoch_number = 0
        self.class_name = class_name

    def forward(self, input_ids, labels):
        """Run the masked LM; when ``labels`` is given, ``output[0]`` is the loss."""
        return self.bert(input_ids=input_ids, labels=labels)

    def training_step(self, batch, batch_idx):
        # Batches come from DataCollatorForLanguageModeling and carry the
        # masked `input_ids` together with the MLM `labels`.
        outputs = self(input_ids=batch["input_ids"], labels=batch["labels"])
        return {"loss": outputs[0]}

    def training_epoch_end(self, outputs):
        """Print the mean training loss of the finished epoch.

        NOTE(review): this hook belongs to the PyTorch Lightning 1.x API; the
        install cell does not pin a version, so confirm it still exists.
        """
        super().training_epoch_end(outputs)
        if outputs:
            # Detach first: `.numpy()` raises on tensors that still require grad.
            losses = torch.stack([out['loss'].detach() for out in outputs])
            mean_loss = losses.mean().item()
        else:
            mean_loss = 0
        print(f"End of epoch {self.epoch_number} with mean loss '{mean_loss}' on label {self.class_name}.", "fine_tuning")
        self.epoch_number += 1

    def configure_optimizers(self):
        # NOTE(review): transformers.AdamW is deprecated in favour of
        # torch.optim.AdamW (whose default eps / weight_decay differ).
        return AdamW(self.parameters(), lr=1e-5)

class BertLMPred(nn.Module):
    """Plain ``nn.Module`` twin of ``BertLM`` used to reload its weights for generation.

    The attribute is named ``bert`` exactly as in ``BertLM``, so a ``BertLM``
    state_dict loads into this module unchanged.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertForMaskedLM.from_pretrained('bert-base-uncased')

    def forward(self, input_ids, labels=None):
        """Forward to the masked LM; ``labels`` is optional (None when generating)."""
        return self.bert(labels=labels, input_ids=input_ids)

In [None]:
# BERT special tokens used to build the generation template.
MASK, SEP = '[MASK]', '[SEP]'
# Same bert-base-uncased tokenizer as in the setup cell, loaded again here.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize_batch(batch, tokenizer):
    """Convert every token list in ``batch`` to its list of vocabulary ids."""
    return list(map(tokenizer.convert_tokens_to_ids, batch))

def get_init_text(seed_sentence, max_len, tokenizer, batch_size=1):
    """Build ``batch_size`` id sequences of the form seed + [MASK]*max_len + [SEP].

    ``seed_sentence`` must be a list of tokens (e.g. ``['[CLS]']``). Each row is
    built separately so rows can be modified independently later.
    """
    rows = []
    for _ in range(batch_size):
        rows.append(seed_sentence + [MASK] * max_len + [SEP])
    return tokenize_batch(rows, tokenizer)

def untokenize_batch(batch, tokenizer):
    """Inverse of ``tokenize_batch``: map every id list back to its tokens."""
    return list(map(tokenizer.convert_ids_to_tokens, batch))

def generate_step(out, gen_idx, temperature=None, top_k=0,
                  sample=False, return_list=True):
    """ Pick a token id for position ``gen_idx`` of every sequence in the batch.

    args:
        - out: model output exposing ``.logits`` (as returned by the masked LM)
          or a raw logits tensor; the logits are indexed as
          batch_size x seq_len x vocab_size
        - gen_idx (int): position to generate a token for
        - temperature (float or None): if given, logits are divided by it
        - top_k (int): if >0, sample only among the top_k highest-scoring
          tokens (clamped to the vocabulary size)
        - sample (bool): if True and top_k == 0, sample from the full
          distribution; otherwise take the argmax. top_k takes precedence.
        - return_list (bool): return ``idx.tolist()`` instead of the tensor
    returns:
        the chosen ids. In the full-distribution sampling branch a batch of
        size 1 is squeezed, so the list form is then a bare int.
    """
    logits = out.logits if hasattr(out, 'logits') else out
    logits = logits[:, gen_idx]
    if temperature is not None:
        logits = logits / temperature
    if top_k > 0:
        # topk raises if k exceeds the vocabulary size.
        k = min(top_k, logits.size(-1))
        kth_vals, kth_idx = logits.topk(k, dim=-1)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        # Map the sampled rank back to the vocabulary id.
        idx = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1)
    elif sample:
        dist = torch.distributions.categorical.Categorical(logits=logits)
        idx = dist.sample().squeeze(-1)
    else:
        idx = torch.argmax(logits, dim=-1)
    return idx.tolist() if return_list else idx


def parallel_sequential_generation(seed_text, tokenizer, model, batch_size=1,
                                   max_len=15, top_k=0, temperature=None,
                                   max_iter=300, burnin=200, cuda=False,
                                   print_every=10, verbose=True):
    """ Generate by repeatedly re-masking and re-predicting one random position.

    args:
        - seed_text: list of seed tokens (e.g. ['[CLS]']); a plain string is
          split on whitespace first
        - max_len: number of [MASK] positions appended after the seed
        - max_iter: total number of single-position updates
        - burnin: for the first ``burnin`` iterations sample from the full
          distribution; afterwards use top-k sampling when ``top_k > 0``,
          otherwise take the argmax
        - cuda: move each input batch to the GPU
        - print_every, verbose: currently unused (kept for API compatibility)
    returns:
        list of ``batch_size`` token lists (seed + generated tokens + [SEP])
    """
    if isinstance(seed_text, str):
        # get_init_text concatenates the seed with a token list.
        seed_text = seed_text.split()
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, tokenizer, batch_size=batch_size)
    mask_id = tokenizer.convert_tokens_to_ids([MASK])[0]
    # Inference only: avoid building an autograd graph on every forward pass.
    with torch.no_grad():
        for ii in range(max_iter):
            pos = seed_len + np.random.randint(0, max_len)
            for row in batch:
                row[pos] = mask_id
            inp = torch.tensor(batch)
            if cuda:
                inp = inp.cuda()
            out = model(inp)
            in_burnin = ii < burnin
            idxs = generate_step(out, gen_idx=pos,
                                 top_k=0 if in_burnin else top_k,
                                 temperature=temperature, sample=in_burnin)
            # generate_step may return a bare int for a batch of one.
            if isinstance(idxs, int):
                idxs = [idxs]
            for row, new_id in zip(batch, idxs):
                row[pos] = new_id

    return untokenize_batch(batch, tokenizer)


def generate(tokenizer, n_samples, class_name, model, seed_text="[CLS]",
             batch_size=1, max_len=25, sample=True, top_k=100, temperature=1.0,
             burnin=200, max_iter=500, cuda=False, print_every=1):
    """Main entry point: produce ``n_samples`` generated token lists.

    Sentences are generated ``batch_size`` at a time with
    ``parallel_sequential_generation``. ``seed_text`` may be a token list or a
    whitespace-separated string such as the default ``"[CLS]"``.
    ``class_name`` and ``sample`` are currently unused (kept for API
    compatibility). A progress line is printed every ``print_every`` batches.
    """
    if isinstance(seed_text, str):
        # The generator concatenates the seed with a token list, so the
        # default string seed must be split into tokens first.
        seed_text = seed_text.split()
    sentences = []
    n_batches = math.ceil(n_samples / batch_size)
    start_time = time.time()
    for batch_n in range(n_batches):
        batch = parallel_sequential_generation(seed_text, tokenizer, model,
                                               batch_size=batch_size,
                                               max_len=max_len, top_k=top_k,
                                               temperature=temperature,
                                               burnin=burnin, max_iter=max_iter,
                                               cuda=cuda, verbose=False)

        if (batch_n + 1) % print_every == 0:
            print("Finished batch %d in %.3fs" % (batch_n + 1, time.time() - start_time), "fine_tuning")
            start_time = time.time()

        sentences += batch
    # The last batch overshoots when n_samples is not a multiple of batch_size.
    return sentences[:n_samples]

def standardize_sentence(sent):
    """Merge WordPiece continuation tokens (``##xyz``) into the preceding word.

    e.g. ['[CLS]', 'hel', '##lo', 'world'] -> ['[CLS]', 'hello', 'world'].
    A leading ``##`` token is kept as-is; an empty input returns ``[]``.
    """
    words = []
    for token in sent:
        if words and token.startswith('##'):
            words[-1] += token[2:]
        else:
            words.append(token)
    return words

def fine_tune_LM(class_name, tokenizer, epochs, batch_size,
                 train_paths, labels, save_url=None, mlm_prob=0.25,
                 use_gpu=True):
    """Fine-tune a BERT masked LM (``BertLM``) on the texts in ``train_paths``.

    args:
        - class_name: name used in log messages
        - tokenizer: tokenizer for the dataset and the MLM collator
        - epochs, batch_size: training schedule
        - train_paths, labels: CSV files and their labels
          (see LMSpellCheckingDataset)
        - save_url: if given, the trained ``state_dict`` is written there;
          missing parent directories are created
        - mlm_prob: masking probability passed to the collator
        - use_gpu: train on one GPU instead of the CPU
    """
    dataset = LMSpellCheckingDataset(tokenizer, train_paths, labels)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                                    mlm=True,
                                                    mlm_probability=mlm_prob)
    # Shuffle so epochs do not see the CSV rows in file order.
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              collate_fn=data_collator)
    model = BertLM(class_name)

    # `checkpoint_callback` is deprecated since PL 1.5 (see the warning in the
    # log); `enable_checkpointing` replaces it, `accelerator`/`devices` replace `gpus`.
    trainer_kwargs = dict(max_epochs=epochs,
                          enable_checkpointing=False,
                          logger=False)
    if use_gpu:
        trainer_kwargs.update(accelerator="gpu", devices=1)
    trainer = pl.Trainer(**trainer_kwargs)

    print(f"Start fine tuning BERT masked LM on class {class_name}", "fine_tuning")
    trainer.fit(model, train_loader)
    if save_url is not None:
        print(f"Finished training. Saving model in {save_url}", "fine_tuning")
        save_dir = os.path.dirname(save_url)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), save_url)

def fine_tune(tokenizer):
    """Generate sample sentences from each saved per-class masked LM.

    Despite the name, nothing is trained here. For the labels 'false' and
    'true' it loads ``<label>_bert_lm.pt`` into ``BertLMPred``, generates 10
    sentences on the GPU and writes them to ``reports/<label>.txt``. WordPiece
    pieces are merged, '[' / ']' become '<' / '>', and each line ends with two
    backslashes.
    """
    labels = ['false', 'true']
    bert_model_url = "/content/drive/MyDrive/models/"
    reports_dir = "/content/drive/MyDrive/reports/"
    os.makedirs(reports_dir, exist_ok=True)

    for label in labels:
        print('----------\n')
        print(f"Generating sentences for {label} label ...", "fine_tuning")
        model_url = os.path.join(bert_model_url, f"{label}_bert_lm.pt")
        model = BertLMPred()
        # Load on the CPU first so checkpoints saved from any device work.
        model.load_state_dict(torch.load(model_url, map_location='cpu'))
        # A freshly built nn.Module is in training mode; disable dropout.
        model = model.cuda().eval()
        sentences = generate(tokenizer, 10, label, model, max_len=10, seed_text='[CLS]'.split(), cuda=True)
        out_url = os.path.join(reports_dir, f"{label}.txt")
        with open(out_url, 'w', encoding='utf-8') as outf:
            for sent in sentences:
                sent = standardize_sentence(sent)
                outf.write(' '.join(sent).replace('[', '<').replace(']', '>'))
                outf.write("\\\\")
                outf.write("\n")
        print(f"Sentences for {label} saved to {out_url}.", "fine_tuning")
        print(sentences)
        print('----------\n')
        
# Fine-tune one masked LM per class, then sample sentences from both.
fine_tune_LM('true', tokenizer, epochs=20, batch_size=32,
             train_paths=[true_train_path], labels=[true_label],
             save_url=true_saved_model_path)

fine_tune_LM('false', tokenizer, epochs=20, batch_size=32,
             train_paths=[false_train_path], labels=[false_label],
             save_url=false_saved_model_path)

fine_tune(tokenizer)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Start fine tuning BERT masked LM on class true fine_tuning


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type            | Params
-----------------------------------------
0 | bert | BertForMaskedLM | 109 M 
-----------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
438.057   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


Finished training. Saving model in /content/drive/MyDrive/models/true_bert_lm.pt fine_tuning


# Section 2: LSTM baseline classifier on a word-level vocabulary

In [None]:
import nltk

# Punkt tokenizer data; nltk.word_tokenize (used by the dataset below) relies on it.
nltk.download(info_or_id='punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# NOTE(review): Sections 1 and 3 read from '/content/drive/MyDrive/data'; this
# section uses a different Drive folder — confirm which location is intended.
data_path = '/content/drive/MyDrive/University/NLP/Final-Project/data'

# Each split folder holds one CSV per class.
train_data_path = data_path + '/train'
true_train_path = train_data_path + '/true.csv'
false_train_path = train_data_path + '/false.csv'

test_data_path = data_path + '/test'
true_test_path = test_data_path + '/true.csv'
false_test_path = test_data_path + '/false.csv'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Pytorch's nn module has lots of useful feature
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from nltk.tokenize import sent_tokenize, word_tokenize
from torch.nn.utils.rnn import pad_sequence
import csv

def pad_batched_sequence(batch, device='cuda'):
    """Collate ``(token_ids, label)`` pairs into padded tensors.

    Returns ``(tweets, tweets_lengths, labels)``:
        - tweets: batch_size x max_len tensor, right-padded with 0 ('<pad>')
        - tweets_lengths: unpadded length of each sequence
        - labels: one label per sequence
    All three are placed on ``device`` (CUDA by default, as before).
    """
    tweets = [torch.tensor(tweet) for tweet, _ in batch]
    tweets_lengths = torch.tensor([len(tweet) for tweet, _ in batch])
    labels = torch.tensor([label for _, label in batch])
    # Pad on the host and transfer each tensor once, not once per sequence.
    tweets = pad_sequence(tweets, padding_value=0, batch_first=True)
    return tweets.to(device), tweets_lengths.to(device), labels.to(device)

# Global word -> id vocabulary shared by every SpellCheckingDataset instance;
# ids 0-2 are reserved for padding and the sequence boundary markers.
word_to_idx = {'<pad>': 0, '<start>': 1, '<stop>': 2}


class SpellCheckingDataset(Dataset):
    """Word-level dataset for the LSTM classifier.

    The first column of every CSV row in ``data_paths`` is tokenized with
    ``word_tokenize``, wrapped in ``<start>``/``<stop>`` and stored with the
    matching entry of ``labels``. Unseen words are added to the global
    ``word_to_idx``; ``vocab_size`` is its size after loading.

    NOTE(review): the running id restarts at 3 for every new instance, so words
    first seen by a later dataset (e.g. the test split) receive ids already
    assigned to other words. Confirm whether this is intended.
    """

    def __init__(self, data_paths, labels, batch_size=32):
        self.dataset = []

        next_id = 3
        for path_idx, data_path in enumerate(data_paths):
            with open(data_path, 'r', encoding='utf-8') as file:
                for row in csv.reader(file):
                    tokens = ['<start>', *word_tokenize(row[0]), '<stop>']
                    for token in tokens:
                        if token in word_to_idx:
                            continue
                        word_to_idx[token] = next_id
                        next_id += 1
                    self.dataset.append((tokens, labels[path_idx]))
        self.batch_size = batch_size
        self.vocab_size = len(word_to_idx)

    def __getitem__(self, idx):
        """Return ``(token_ids, label)`` for item ``idx``."""
        tokens, label = self.dataset[idx]
        ids = [word_to_idx[token] for token in tokens]
        return ids, label

    def __len__(self):
        return len(self.dataset)

class LSTMNet(nn.Module):
    """Embedding -> (bi)LSTM -> linear -> sigmoid classifier.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: embedding width.
        hidden_dim: LSTM hidden size per direction.
        output_dim: number of sigmoid outputs.
        n_layers: number of stacked LSTM layers.
        bidirectional: use a bidirectional LSTM; the classifier then sees the
            top layer's final forward and backward states concatenated.
        dropout: dropout between LSTM layers (nn.LSTM only applies it when
            n_layers > 1).
    """

    def __init__(self,vocab_size,embedding_dim,hidden_dim,output_dim,n_layers,bidirectional,dropout):

        super(LSTMNet,self).__init__()
        self.bidirectional = bidirectional

        # Embedding layer converts integer sequences to vector sequences
        self.embedding = nn.Embedding(vocab_size,embedding_dim)

        # LSTM layer processes the vector sequences
        self.lstm = nn.LSTM(embedding_dim,
                            hidden_dim,
                            num_layers = n_layers,
                            bidirectional = bidirectional,
                            dropout = dropout,
                            batch_first = True
                           )

        # The classifier input width depends on how many final states are concatenated.
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * num_directions, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self,text,text_lengths):
        """Return sigmoid scores of shape (batch, output_dim).

        ``text`` is a batch-first tensor of token ids and ``text_lengths`` the
        unpadded length of each row (moved to the CPU, as packing requires).
        """
        embedded = self.embedding(text)

        # Packing keeps the LSTM from processing padding tokens.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.cpu(),batch_first=True, enforce_sorted=False)

        _packed_output, (hidden_state, _cell_state) = self.lstm(packed_embedded)

        if self.bidirectional:
            # hidden_state is (num_layers * 2, batch, hidden); the last two rows
            # are the top layer's forward and backward final states.
            hidden = torch.cat((hidden_state[-2,:,:], hidden_state[-1,:,:]), dim = 1)
        else:
            hidden = hidden_state[-1,:,:]

        return self.sigmoid(self.fc(hidden))
    

In [None]:

true_label, false_label = 1, 0

# Both splits register their words in the shared global `word_to_idx`.
train_dataset = SpellCheckingDataset([true_train_path, false_train_path],
                                     [true_label, false_label])
test_dataset = SpellCheckingDataset([true_test_path, false_test_path],
                                    [true_label, false_label])

# LSTM hyper-parameters; the vocabulary size is read from the training split.
SIZE_OF_VOCAB = train_dataset.vocab_size
EMBEDDING_DIM = 100
NUM_HIDDEN_NODES = 64
NUM_OUTPUT_NODES = 1
NUM_LAYERS = 2
BIDIRECTION = True
DROPOUT = 0.2

In [None]:
# Classifier built from the hyper-parameters defined above.
model = LSTMNet(vocab_size=SIZE_OF_VOCAB,
                embedding_dim=EMBEDDING_DIM,
                hidden_dim=NUM_HIDDEN_NODES,
                output_dim=NUM_OUTPUT_NODES,
                n_layers=NUM_LAYERS,
                bidirectional=BIDIRECTION,
                dropout=DROPOUT)



In [None]:
import torch.optim as optim

model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# BCELoss takes probabilities, matching the model's sigmoid output.
criterion = nn.BCELoss().cuda()

In [None]:
def binary_accuracy(preds, y):
    """Fraction of predictions matching ``y`` after rounding them to 0/1.

    ``preds`` are probabilities and ``y`` the 0/1 targets. Both are flattened
    first, so a (batch, 1) prediction is compared element-wise with a (batch,)
    target instead of broadcasting to (batch, batch), and a 0-d tensor (a
    squeezed batch of one) also works. Returns a 0-d float tensor.
    """
    # torch.round rounds half to even, so exactly 0.5 becomes 0.
    rounded_preds = torch.round(preds).reshape(-1)
    correct = (rounded_preds == y.reshape(-1)).float()
    return correct.mean()

In [None]:
def train(model,iterator,optimizer,criterion):
    """Run one training epoch; return (mean batch loss, mean batch accuracy).

    ``iterator`` yields ``(text, text_lengths, labels)`` batches. ``model`` is
    assumed to return (batch, 1) probabilities (as LSTMNet with one output does).
    """
    epoch_loss = 0.0
    epoch_acc = 0.0

    # Enable dropout.
    model.train()

    for text, text_lengths, labels in iterator:
        optimizer.zero_grad()

        # squeeze(-1) rather than squeeze(): keeps the batch axis when a batch
        # holds a single example.
        predictions = model(text,text_lengths).squeeze(-1)

        loss = criterion(predictions.double(),labels.double())
        loss.backward()

        acc = binary_accuracy(predictions,labels)

        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)
        

In [None]:
def evaluate(model,iterator,criterion):
    """Score ``model`` on ``iterator``; return (mean batch loss, mean batch accuracy).

    Dropout is disabled via ``model.eval()`` and gradients are not tracked.
    """
    epoch_loss = 0.0
    epoch_acc = 0.0

    model.eval()

    with torch.no_grad():
        for text, text_lengths, labels in iterator:
            # squeeze(-1) keeps the batch axis for a batch of one.
            predictions = model(text,text_lengths).squeeze(-1)

            loss = criterion(predictions.double(), labels.double())
            acc = binary_accuracy(predictions, labels)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [None]:
from torch.utils.data.dataloader import DataLoader

EPOCH_NUMBER = 15
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          drop_last=True,
                          collate_fn=pad_batched_sequence)

# Evaluation does not need shuffling. A fixed order means drop_last always skips
# the same examples, so validation numbers are comparable across epochs.
# NOTE(review): drop_last=True still ignores up to BATCH_SIZE-1 test examples.
test_loader = DataLoader(test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         drop_last=True,
                         collate_fn=pad_batched_sequence)

for epoch in range(1, EPOCH_NUMBER + 1):

    train_loss, train_acc = train(model, train_loader, optimizer, criterion)

    valid_loss, valid_acc = evaluate(model, test_loader, criterion)

    print(f'Epoch {epoch:02d}/{EPOCH_NUMBER}')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')
    print()
    print()

	Train Loss: 0.692 | Train Acc: 52.11%
	 Val. Loss: 0.692 |  Val. Acc: 51.88%

	Train Loss: 0.669 | Train Acc: 58.39%
	 Val. Loss: 0.698 |  Val. Acc: 53.89%

	Train Loss: 0.632 | Train Acc: 63.48%
	 Val. Loss: 0.698 |  Val. Acc: 55.10%

	Train Loss: 0.589 | Train Acc: 68.44%
	 Val. Loss: 0.714 |  Val. Acc: 56.64%

	Train Loss: 0.542 | Train Acc: 72.23%
	 Val. Loss: 0.740 |  Val. Acc: 57.96%

	Train Loss: 0.497 | Train Acc: 75.29%
	 Val. Loss: 0.743 |  Val. Acc: 59.78%

	Train Loss: 0.457 | Train Acc: 78.12%
	 Val. Loss: 0.779 |  Val. Acc: 59.41%

	Train Loss: 0.414 | Train Acc: 80.82%
	 Val. Loss: 0.835 |  Val. Acc: 59.97%

	Train Loss: 0.380 | Train Acc: 82.61%
	 Val. Loss: 0.869 |  Val. Acc: 60.98%

	Train Loss: 0.344 | Train Acc: 84.58%
	 Val. Loss: 0.956 |  Val. Acc: 60.26%

	Train Loss: 0.310 | Train Acc: 86.54%
	 Val. Loss: 0.977 |  Val. Acc: 62.02%

	Train Loss: 0.278 | Train Acc: 88.12%
	 Val. Loss: 1.016 |  Val. Acc: 61.99%

	Train Loss: 0.249 | Train Acc: 89.44%
	 Val. Loss: 

# Section 3: BERT-based classifier trained separately on each class

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Root folder of the dataset on Drive (same location as Section 1).
data_path = '/content/drive/MyDrive/data'

# Each split folder holds one CSV per class.
train_data_path = data_path + '/train'
true_train_path = train_data_path + '/true.csv'
false_train_path = train_data_path + '/false.csv'

test_data_path = data_path + '/test'
true_test_path = test_data_path + '/true.csv'
false_test_path = test_data_path + '/false.csv'

Mounted at /content/drive


In [None]:


class TextClassificationModel(nn.Module):
    """BERT encoder + linear head + sigmoid applied to the first-token embedding.

    Args:
        embed_dim: width of BERT's hidden state (768 for bert-base).
        num_class: number of sigmoid outputs.
        device: device holding the model and its inputs (``'cuda'`` by
            default, as before).
    """

    def __init__(self, embed_dim=768, num_class=1, device='cuda'):
        super(TextClassificationModel, self).__init__()
        self._device = torch.device(device)
        self.bert = BertModel.from_pretrained("bert-base-uncased").to(self._device)
        self.fc = nn.Linear(embed_dim, num_class).to(self._device)
        # Sigmoid has no parameters, so it needs no device placement.
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_masks):
        """Return (batch, num_class) probabilities computed from the first token."""
        outputs = self.bert(input_ids=input_ids.to(self._device),
                            attention_mask=attention_masks.to(self._device))
        first_token = outputs.last_hidden_state[:, 0, :]
        return self.sigmoid(self.fc(first_token))


In [None]:
class SpellCheckingDataset(Dataset):
    """Dataset of labelled texts for the BERT classifier.

    The first column of every CSV row is one text, paired with the entry of
    ``labels`` matching its file.

    Args:
        tokenizer: callable (e.g. a HuggingFace tokenizer) returning a mapping
            with ``'input_ids'`` and ``'attention_mask'`` entries.
        data_paths: list of CSV file paths.
        labels: one label per entry of ``data_paths``.
        batch_size: stored for reference; not used by the dataset itself.

    Raises:
        ValueError: if ``data_paths`` and ``labels`` differ in length.
    """

    def __init__(self, tokenizer, data_paths, labels, batch_size=32):
        if len(data_paths) != len(labels):
            raise ValueError('data_paths and labels must have the same length')
        self.dataset = []
        for data_path, label in zip(data_paths, labels):
            # newline='' is what the csv module expects for quoted fields.
            with open(data_path, 'r', encoding='utf-8', newline='') as file:
                for item in csv.reader(file):
                    if not item:
                        # Blank lines come back as [] and would fail on item[0].
                        continue
                    self.dataset.append((item[0], label))
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for item ``idx``."""
        if self.tokenizer is None:
            raise Exception('Tokenizer cannot be null')

        tweet, label = self.dataset[idx]
        tokenized_tweet = self.tokenizer(tweet)
        return tokenized_tweet['input_ids'], tokenized_tweet['attention_mask'], label

    def __len__(self):
        return len(self.dataset)

class LMSpellCheckingDataset(Dataset):
    """Dataset of raw texts for masked-language-model fine-tuning.

    The first column of every CSV row is one text. The label from ``labels`` is
    stored next to it, but ``__getitem__`` returns only the tokenized
    ``input_ids``; masking is left to the collate function.

    Args:
        tokenizer: callable (e.g. a HuggingFace tokenizer) returning a mapping
            with an ``'input_ids'`` entry.
        data_paths: list of CSV file paths.
        labels: one label per entry of ``data_paths``.
        batch_size: stored for reference; not used by the dataset itself.

    Raises:
        ValueError: if ``data_paths`` and ``labels`` differ in length.
    """

    def __init__(self, tokenizer, data_paths, labels, batch_size=32):
        if len(data_paths) != len(labels):
            raise ValueError('data_paths and labels must have the same length')
        self.dataset = []
        for data_path, label in zip(data_paths, labels):
            # newline='' is what the csv module expects for quoted fields.
            with open(data_path, 'r', encoding='utf-8', newline='') as file:
                for item in csv.reader(file):
                    if not item:
                        # Blank lines come back as [] and would fail on item[0].
                        continue
                    self.dataset.append((item[0], label))
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def __getitem__(self, idx):
        """Return the token ids of item ``idx`` as a 1-D tensor."""
        if self.tokenizer is None:
            raise Exception('Tokenizer cannot be null')

        tweet, _label = self.dataset[idx]
        input_ids = self.tokenizer(tweet)['input_ids']
        return torch.tensor(input_ids)

    def __len__(self):
        return len(self.dataset)

In [None]:

def pad_batched_sequence(batch, device='cuda'):
    """Collate ``(input_ids, attention_mask, label)`` triples into padded tensors.

    Returns ``(input_ids, attention_masks, labels)``:
        - input_ids, attention_masks: batch_size x max_len tensors,
          right-padded with 0
        - labels: (batch_size, 1) float64 tensor, or ``None`` when the first
          item's label is ``None``
    All tensors are placed on ``device`` (CUDA by default, as before).
    """
    input_ids = pad_sequence([torch.tensor(item[0]) for item in batch],
                             padding_value=0, batch_first=True)
    attention_masks = pad_sequence([torch.tensor(item[1]) for item in batch],
                                   padding_value=0, batch_first=True)

    labels = None
    if batch[0][2] is not None:
        # Shape (batch, 1) to line up with a (batch, 1) sigmoid output in BCELoss.
        labels = torch.tensor([[item[2]] for item in batch]).double().to(device)

    return input_ids.to(device), attention_masks.to(device), labels

class SpellCheckingTrainer():
    """Minimal training loop for a sigmoid + BCE classifier.

    Args:
        model: module called as ``model(input_ids, attention_masks)``; assumed
            to return (batch, 1) probabilities.
        train_dataset: dataset whose items ``pad_batched_sequence`` can collate.
        save_data_path: file the final ``state_dict`` is written to.
        batch_size, epochs, lr: optimisation settings.
            NOTE(review): lr=1e-3 is far above the usual range for fine-tuning
            BERT with Adam; confirm the value.
    """

    def __init__(self,
                 model,
                 train_dataset,
                 save_data_path,
                 batch_size=32,
                 epochs=20,
                 lr=0.001):
        self.model = model
        self.epochs = epochs
        self.batch_size = batch_size
        # drop_last=True keeps every batch at exactly batch_size, which the
        # accuracy computation below relies on.
        self.train_loader = DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       drop_last=True,
                                       collate_fn=pad_batched_sequence)
        self.save_path = save_data_path
        self.loss_function = nn.BCELoss()
        self.optimizer = torch.optim.Adam(list(self.model.parameters()), lr=lr)

    def train_one_epoch(self, epoch_index):
        """Make one pass over ``train_loader``.

        Every 10 batches the mean loss and accuracy of those 10 batches are
        printed. Returns the mean loss of the most recent 10-batch window
        (0. if the epoch had fewer than 10 batches). ``epoch_index`` is unused.
        """
        running_loss = 0.
        running_accuracy = 0.
        last_loss = 0.
        threshold = 0.5

        for i, (input_ids, attention_masks, labels) in tqdm.tqdm(
                enumerate(self.train_loader), total=len(self.train_loader)):
            self.optimizer.zero_grad()

            outputs = self.model(input_ids, attention_masks)

            loss = self.loss_function(outputs.double(), labels.double())
            loss.backward()

            self.optimizer.step()

            running_loss += loss.item()
            predicted = (outputs > threshold).float()
            running_accuracy += (predicted == labels).sum().item() / self.batch_size

            if i % 10 == 9:
                last_loss = running_loss / 10  # mean loss per batch
                last_accuracy = running_accuracy / 10  # mean accuracy per batch
                print('  batch {} loss: {}'.format(i + 1, last_loss))
                print('  batch {} accuracy: {}'.format(i + 1, last_accuracy))
                running_loss = 0.
                running_accuracy = 0.

        return last_loss

    def train(self):
        """Train for ``self.epochs`` epochs, then save the model's ``state_dict``."""
        import os

        for epoch in range(self.epochs):
            print('EPOCH {}:'.format(epoch + 1))
            # Gradient tracking / dropout on for the training pass.
            self.model.train(True)
            self.train_one_epoch(epoch)
            self.model.train(False)

        save_dir = os.path.dirname(self.save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(self.model.state_dict(), self.save_path)
    
          
          

In [None]:


# WordPiece tokenizer shared by the classification datasets.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# NOTE(review): `model` is not used by the visible cells of this section
# (each TextClassificationModel loads its own BertModel); confirm it is needed.
model = BertModel.from_pretrained("bert-base-uncased")

true_label, false_label = 1, 0

# NOTE(review): each training set below contains a single label, so each BCE
# classifier can reach ~0 loss by always predicting that label. A discriminative
# model needs both classes in one training set.
true_train_dataset = SpellCheckingDataset(tokenizer, [true_train_path], [true_label])
false_train_dataset = SpellCheckingDataset(tokenizer, [false_train_path], [false_label])

test_dataset = SpellCheckingDataset(tokenizer,
                                    [true_test_path, false_test_path],
                                    [true_label, false_label])

true_saved_model_path = '/content/drive/MyDrive/models/true_bert.berm_lm'
false_saved_model_path = '/content/drive/MyDrive/models/false_bert.berm_lm'

# One fresh BERT classifier per class.
true_trainer = SpellCheckingTrainer(model=TextClassificationModel(),
                                    train_dataset=true_train_dataset,
                                    save_data_path=true_saved_model_path)

false_trainer = SpellCheckingTrainer(model=TextClassificationModel(),
                                     train_dataset=false_train_dataset,
                                     save_data_path=false_saved_model_path)
 

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.trans

In [None]:
# Train and save both per-class classifiers, 'true' first.
for per_class_trainer in (true_trainer, false_trainer):
    per_class_trainer.train()




EPOCH 1:


  2%|▏         | 10/403 [00:03<02:13,  2.95it/s]

  batch 10 loss: 0.0809702505817
  batch 10 accuracy: 0.90625


  5%|▍         | 20/403 [00:06<02:03,  3.09it/s]

  batch 20 loss: 9.74941587150786e-05
  batch 20 accuracy: 1.0


  7%|▋         | 30/403 [00:10<02:06,  2.95it/s]

  batch 30 loss: 7.665900075088309e-05
  batch 30 accuracy: 1.0


 10%|▉         | 40/403 [00:13<02:01,  2.98it/s]

  batch 40 loss: 6.321273556593354e-05
  batch 40 accuracy: 1.0


 12%|█▏        | 50/403 [00:16<02:01,  2.91it/s]

  batch 50 loss: 5.304097863122118e-05
  batch 50 accuracy: 1.0


 15%|█▍        | 60/403 [00:20<01:56,  2.96it/s]

  batch 60 loss: 4.5547186205881256e-05
  batch 60 accuracy: 1.0


 17%|█▋        | 70/403 [00:23<01:49,  3.05it/s]

  batch 70 loss: 3.9787640177359925e-05
  batch 70 accuracy: 1.0


 20%|█▉        | 80/403 [00:26<01:51,  2.90it/s]

  batch 80 loss: 3.525267347145879e-05
  batch 80 accuracy: 1.0


 22%|██▏       | 90/403 [00:30<01:52,  2.79it/s]

  batch 90 loss: 3.170458740125791e-05
  batch 90 accuracy: 1.0


 25%|██▍       | 100/403 [00:33<01:43,  2.92it/s]

  batch 100 loss: 2.8935122533390755e-05
  batch 100 accuracy: 1.0


 27%|██▋       | 110/403 [00:37<01:43,  2.82it/s]

  batch 110 loss: 2.6636554094839917e-05
  batch 110 accuracy: 1.0


 30%|██▉       | 120/403 [00:40<01:41,  2.80it/s]

  batch 120 loss: 2.491654307511618e-05
  batch 120 accuracy: 1.0


 32%|███▏      | 130/403 [00:44<01:36,  2.83it/s]

  batch 130 loss: 2.305795099310669e-05
  batch 130 accuracy: 1.0


 35%|███▍      | 140/403 [00:47<01:31,  2.87it/s]

  batch 140 loss: 2.1738423486908814e-05
  batch 140 accuracy: 1.0


 37%|███▋      | 150/403 [00:51<01:29,  2.81it/s]

  batch 150 loss: 2.0418897750883825e-05
  batch 150 accuracy: 1.0


 40%|███▉      | 160/403 [00:55<01:27,  2.77it/s]

  batch 160 loss: 1.9377658321589114e-05
  batch 160 accuracy: 1.0


 42%|████▏     | 170/403 [00:58<01:20,  2.90it/s]

  batch 170 loss: 1.8309597432919868e-05
  batch 170 accuracy: 1.0


 45%|████▍     | 180/403 [01:02<01:15,  2.94it/s]

  batch 180 loss: 1.7340259636269857e-05
  batch 180 accuracy: 1.0


 47%|████▋     | 190/403 [01:05<01:19,  2.69it/s]

  batch 190 loss: 1.6409666386565475e-05
  batch 190 accuracy: 1.0


 50%|████▉     | 200/403 [01:09<01:14,  2.72it/s]

  batch 200 loss: 1.5627715316709767e-05
  batch 200 accuracy: 1.0


 52%|█████▏    | 210/403 [01:12<01:08,  2.81it/s]

  batch 210 loss: 1.4994406221393324e-05
  batch 210 accuracy: 1.0


 55%|█████▍    | 220/403 [01:16<01:06,  2.76it/s]

  batch 220 loss: 1.4252690068359558e-05
  batch 220 accuracy: 1.0


 57%|█████▋    | 230/403 [01:20<01:06,  2.59it/s]

  batch 230 loss: 1.3673771893515073e-05
  batch 230 accuracy: 1.0


 60%|█████▉    | 240/403 [01:23<00:57,  2.82it/s]

  batch 240 loss: 1.3108637743068882e-05
  batch 240 accuracy: 1.0


 62%|██████▏   | 250/403 [01:27<00:52,  2.90it/s]

  batch 250 loss: 1.2597148740918153e-05
  batch 250 accuracy: 1.0


 65%|██████▍   | 260/403 [01:30<00:50,  2.84it/s]

  batch 260 loss: 1.2024564569842986e-05
  batch 260 accuracy: 1.0


 67%|██████▋   | 270/403 [01:34<00:49,  2.67it/s]

  batch 270 loss: 1.1501155152340908e-05
  batch 270 accuracy: 1.0


 69%|██████▉   | 280/403 [01:38<00:45,  2.68it/s]

  batch 280 loss: 1.1081682658301286e-05
  batch 280 accuracy: 1.0


 72%|███████▏  | 290/403 [01:41<00:37,  2.99it/s]

  batch 290 loss: 1.0680464593949693e-05
  batch 290 accuracy: 1.0


 74%|███████▍  | 300/403 [01:45<00:34,  2.96it/s]

  batch 300 loss: 1.025801221713314e-05
  batch 300 accuracy: 1.0


 77%|███████▋  | 310/403 [01:48<00:32,  2.90it/s]

  batch 310 loss: 9.864990090937238e-06
  batch 310 accuracy: 1.0


 79%|███████▉  | 320/403 [01:52<00:28,  2.87it/s]

  batch 320 loss: 9.444773294338556e-06
  batch 320 accuracy: 1.0


 82%|████████▏ | 330/403 [01:55<00:24,  3.00it/s]

  batch 330 loss: 9.10316099819139e-06
  batch 330 accuracy: 1.0


 84%|████████▍ | 340/403 [01:58<00:21,  2.98it/s]

  batch 340 loss: 8.711256937489243e-06
  batch 340 accuracy: 1.0


 87%|████████▋ | 350/403 [02:02<00:18,  2.90it/s]

  batch 350 loss: 8.334254327664292e-06
  batch 350 accuracy: 1.0


 89%|████████▉ | 360/403 [02:05<00:14,  2.94it/s]

  batch 360 loss: 8.08316768123915e-06
  batch 360 accuracy: 1.0


 92%|█████████▏| 370/403 [02:09<00:11,  2.75it/s]

  batch 370 loss: 7.80898418846819e-06
  batch 370 accuracy: 1.0


 94%|█████████▍| 380/403 [02:12<00:07,  2.99it/s]

  batch 380 loss: 7.509841063534641e-06
  batch 380 accuracy: 1.0


 97%|█████████▋| 390/403 [02:16<00:04,  2.79it/s]

  batch 390 loss: 7.2412456212096496e-06
  batch 390 accuracy: 1.0


 99%|█████████▉| 400/403 [02:19<00:01,  2.72it/s]

  batch 400 loss: 7.040451081027228e-06
  batch 400 accuracy: 1.0


100%|██████████| 403/403 [02:21<00:00,  2.86it/s]


EPOCH 2:


  2%|▏         | 10/403 [00:03<02:24,  2.73it/s]

  batch 10 loss: 6.685428441123563e-06
  batch 10 accuracy: 1.0


  5%|▍         | 20/403 [00:07<02:22,  2.69it/s]

  batch 20 loss: 6.439557668171569e-06
  batch 20 accuracy: 1.0


  7%|▋         | 30/403 [00:10<02:15,  2.76it/s]

  batch 30 loss: 6.238018198450572e-06
  batch 30 accuracy: 1.0


 10%|▉         | 40/403 [00:14<02:10,  2.79it/s]

  batch 40 loss: 5.998108032219447e-06
  batch 40 accuracy: 1.0


 12%|█▏        | 50/403 [00:18<02:12,  2.66it/s]

  batch 50 loss: 5.739943863123175e-06
  batch 50 accuracy: 1.0


 15%|█▍        | 60/403 [00:21<02:01,  2.81it/s]

  batch 60 loss: 5.600244717112281e-06
  batch 60 accuracy: 1.0


 17%|█▋        | 70/403 [00:25<01:55,  2.89it/s]

  batch 70 loss: 5.362197343447454e-06
  batch 70 accuracy: 1.0


 20%|█▉        | 80/403 [00:28<01:50,  2.93it/s]

  batch 80 loss: 5.183754988794325e-06
  batch 80 accuracy: 1.0


 22%|██▏       | 90/403 [00:32<01:51,  2.81it/s]

  batch 90 loss: 5.026174404764259e-06
  batch 90 accuracy: 1.0


 25%|██▍       | 100/403 [00:36<01:58,  2.55it/s]

  batch 100 loss: 4.839163919218084e-06
  batch 100 accuracy: 1.0


 27%|██▋       | 110/403 [00:39<01:49,  2.67it/s]

  batch 110 loss: 4.668172264800533e-06
  batch 110 accuracy: 1.0


 30%|██▉       | 120/403 [00:43<01:39,  2.83it/s]

  batch 120 loss: 4.541884359620236e-06
  batch 120 accuracy: 1.0


 32%|███▏      | 130/403 [00:47<01:36,  2.84it/s]

  batch 130 loss: 4.364187216549926e-06
  batch 130 accuracy: 1.0


 35%|███▍      | 140/403 [00:50<01:28,  2.96it/s]

  batch 140 loss: 4.177176813471847e-06
  batch 140 accuracy: 1.0


 37%|███▋      | 150/403 [00:53<01:24,  2.98it/s]

  batch 150 loss: 4.037105333499139e-06
  batch 150 accuracy: 1.0


 40%|███▉      | 160/403 [00:57<01:24,  2.88it/s]

  batch 160 loss: 3.976382892578533e-06
  batch 160 accuracy: 1.0


 42%|████▏     | 170/403 [01:00<01:18,  2.97it/s]

  batch 170 loss: 3.7826670046668317e-06
  batch 170 accuracy: 1.0


 45%|████▍     | 180/403 [01:04<01:15,  2.95it/s]

  batch 180 loss: 3.6500461870483907e-06
  batch 180 accuracy: 1.0


 47%|████▋     | 190/403 [01:07<01:17,  2.75it/s]

  batch 190 loss: 3.541639860797967e-06
  batch 190 accuracy: 1.0


 50%|████▉     | 200/403 [01:11<01:14,  2.74it/s]

  batch 200 loss: 3.449624872567994e-06
  batch 200 accuracy: 1.0


 52%|█████▏    | 210/403 [01:14<01:06,  2.92it/s]

  batch 210 loss: 3.3333954101563583e-06
  batch 210 accuracy: 1.0


 55%|█████▍    | 220/403 [01:18<01:03,  2.90it/s]

  batch 220 loss: 3.255909160827826e-06
  batch 220 accuracy: 1.0


 57%|█████▋    | 230/403 [01:22<01:02,  2.75it/s]

  batch 230 loss: 3.1128575082579684e-06
  batch 230 accuracy: 1.0


 60%|█████▉    | 240/403 [01:25<00:59,  2.73it/s]

  batch 240 loss: 2.9970006208619344e-06
  batch 240 accuracy: 1.0


 62%|██████▏   | 250/403 [01:29<00:53,  2.86it/s]

  batch 250 loss: 2.916906645520316e-06
  batch 250 accuracy: 1.0


 65%|██████▍   | 260/403 [01:32<00:51,  2.77it/s]

  batch 260 loss: 2.8304796507141796e-06
  batch 260 accuracy: 1.0


 67%|██████▋   | 270/403 [01:36<00:47,  2.80it/s]

  batch 270 loss: 2.750758234844919e-06
  batch 270 accuracy: 1.0


 69%|██████▉   | 280/403 [01:39<00:44,  2.77it/s]

  batch 280 loss: 2.6337837761023924e-06
  batch 280 accuracy: 1.0


 72%|███████▏  | 290/403 [01:43<00:38,  2.95it/s]

  batch 290 loss: 2.544749108396631e-06
  batch 290 accuracy: 1.0


 74%|███████▍  | 300/403 [01:46<00:35,  2.94it/s]

  batch 300 loss: 2.508613733681888e-06
  batch 300 accuracy: 1.0


 77%|███████▋  | 310/403 [01:50<00:32,  2.83it/s]

  batch 310 loss: 2.430009903503232e-06
  batch 310 accuracy: 1.0


 79%|███████▉  | 320/403 [01:53<00:29,  2.85it/s]

  batch 320 loss: 2.33464223600622e-06
  batch 320 accuracy: 1.0


 82%|████████▏ | 330/403 [01:57<00:24,  2.92it/s]

  batch 330 loss: 2.280252878898963e-06
  batch 330 accuracy: 1.0


 84%|████████▍ | 340/403 [02:00<00:22,  2.78it/s]

  batch 340 loss: 2.186747873759711e-06
  batch 340 accuracy: 1.0


 87%|████████▋ | 350/403 [02:04<00:18,  2.85it/s]

  batch 350 loss: 2.133476103586534e-06
  batch 350 accuracy: 1.0


 89%|████████▉ | 360/403 [02:07<00:14,  2.97it/s]

  batch 360 loss: 2.1223002476571e-06
  batch 360 accuracy: 1.0


 92%|█████████▏| 370/403 [02:11<00:11,  2.79it/s]

  batch 370 loss: 2.050774511878096e-06
  batch 370 accuracy: 1.0


 94%|█████████▍| 380/403 [02:14<00:07,  2.89it/s]

  batch 380 loss: 1.9945225057357903e-06
  batch 380 accuracy: 1.0


 97%|█████████▋| 390/403 [02:18<00:04,  2.83it/s]

  batch 390 loss: 1.9021351151087097e-06
  batch 390 accuracy: 1.0


 99%|█████████▉| 400/403 [02:21<00:01,  2.90it/s]

  batch 400 loss: 1.8477457748325016e-06
  batch 400 accuracy: 1.0


100%|██████████| 403/403 [02:22<00:00,  2.82it/s]


EPOCH 3:


  2%|▏         | 10/403 [00:03<02:17,  2.85it/s]

  batch 10 loss: 1.772494771270465e-06
  batch 10 accuracy: 1.0


  5%|▍         | 20/403 [00:07<02:20,  2.73it/s]

  batch 20 loss: 1.7270461465843808e-06
  batch 20 accuracy: 1.0


  7%|▋         | 30/403 [00:10<02:16,  2.74it/s]

  batch 30 loss: 1.655892973963622e-06
  batch 30 accuracy: 1.0


 10%|▉         | 40/403 [00:14<02:02,  2.96it/s]

  batch 40 loss: 1.6205026672117424e-06
  batch 40 accuracy: 1.0


 12%|█▏        | 50/403 [00:17<02:04,  2.83it/s]

  batch 50 loss: 1.5687210426455897e-06
  batch 50 accuracy: 1.0


 15%|█▍        | 60/403 [00:21<01:56,  2.95it/s]

  batch 60 loss: 1.5676034761840793e-06
  batch 60 accuracy: 1.0


 17%|█▋        | 70/403 [00:24<02:01,  2.74it/s]

  batch 70 loss: 1.4569621573880241e-06
  batch 70 accuracy: 1.0


 20%|█▉        | 80/403 [00:28<01:56,  2.77it/s]

  batch 80 loss: 1.4565896335318352e-06
  batch 80 accuracy: 1.0


 22%|██▏       | 90/403 [00:31<01:48,  2.89it/s]

  batch 90 loss: 1.400337669889054e-06
  batch 90 accuracy: 1.0


 25%|██▍       | 100/403 [00:35<01:45,  2.88it/s]

  batch 100 loss: 1.3396153419896554e-06
  batch 100 accuracy: 1.0


 27%|██▋       | 110/403 [00:38<01:50,  2.65it/s]

  batch 110 loss: 1.3165185098094439e-06
  batch 110 accuracy: 1.0


 30%|██▉       | 120/403 [00:42<01:41,  2.78it/s]

  batch 120 loss: 1.2647369074922433e-06
  batch 120 accuracy: 1.0


 32%|███▏      | 130/403 [00:45<01:37,  2.81it/s]

  batch 130 loss: 1.2692072694839648e-06
  batch 130 accuracy: 1.0


 35%|███▍      | 140/403 [00:49<01:37,  2.70it/s]

  batch 140 loss: 1.2487181557671881e-06
  batch 140 accuracy: 1.0


 37%|███▋      | 150/403 [00:53<01:29,  2.82it/s]

  batch 150 loss: 1.1812903030213163e-06
  batch 150 accuracy: 1.0


 40%|███▉      | 160/403 [00:56<01:27,  2.77it/s]

  batch 160 loss: 1.1790551351039134e-06
  batch 160 accuracy: 1.0


 42%|████▏     | 170/403 [01:00<01:20,  2.91it/s]

  batch 170 loss: 1.1559583146476969e-06
  batch 170 accuracy: 1.0


 45%|████▍     | 180/403 [01:03<01:16,  2.91it/s]

  batch 180 loss: 1.1239207708741612e-06
  batch 180 accuracy: 1.0


 47%|████▋     | 190/403 [01:07<01:15,  2.82it/s]

  batch 190 loss: 1.0736292957341923e-06
  batch 190 accuracy: 1.0


 50%|████▉     | 200/403 [01:10<01:17,  2.60it/s]

  batch 200 loss: 1.0378664679682866e-06
  batch 200 accuracy: 1.0


 52%|█████▏    | 210/403 [01:14<01:07,  2.84it/s]

  batch 210 loss: 1.0337686611855444e-06
  batch 210 accuracy: 1.0


 55%|█████▍    | 220/403 [01:17<01:05,  2.77it/s]

  batch 220 loss: 1.0289257843082594e-06
  batch 220 accuracy: 1.0


 57%|█████▋    | 230/403 [01:21<01:05,  2.64it/s]

  batch 230 loss: 1.0117894247324221e-06
  batch 230 accuracy: 1.0


 60%|█████▉    | 240/403 [01:25<00:56,  2.90it/s]

  batch 240 loss: 9.682034777252073e-07
  batch 240 accuracy: 1.0


 62%|██████▏   | 250/403 [01:28<00:53,  2.87it/s]

  batch 250 loss: 9.268527051191038e-07
  batch 250 accuracy: 1.0


 65%|██████▍   | 260/403 [01:32<00:51,  2.80it/s]

  batch 260 loss: 8.840118256751149e-07
  batch 260 accuracy: 1.0


 67%|██████▋   | 270/403 [01:35<00:49,  2.67it/s]

  batch 270 loss: 8.527193459564894e-07
  batch 270 accuracy: 1.0


 69%|██████▉   | 280/403 [01:39<00:45,  2.72it/s]

  batch 280 loss: 8.702282411124897e-07
  batch 280 accuracy: 1.0


 72%|███████▏  | 290/403 [01:43<00:41,  2.70it/s]

  batch 290 loss: 8.735810212771479e-07
  batch 290 accuracy: 1.0


 74%|███████▍  | 300/403 [01:46<00:34,  2.95it/s]

  batch 300 loss: 8.501116736340544e-07
  batch 300 accuracy: 1.0


 77%|███████▋  | 310/403 [01:50<00:35,  2.64it/s]

  batch 310 loss: 7.998201982276162e-07
  batch 310 accuracy: 1.0


 79%|███████▉  | 320/403 [01:53<00:29,  2.77it/s]

  batch 320 loss: 8.210543866574756e-07
  batch 320 accuracy: 1.0


 82%|████████▏ | 330/403 [01:56<00:24,  2.96it/s]

  batch 330 loss: 7.834289118105511e-07
  batch 330 accuracy: 1.0


 84%|████████▍ | 340/403 [02:00<00:22,  2.75it/s]

  batch 340 loss: 7.480386129634634e-07
  batch 340 accuracy: 1.0


 87%|████████▋ | 350/403 [02:04<00:18,  2.91it/s]

  batch 350 loss: 7.025900263359417e-07
  batch 350 accuracy: 1.0


 89%|████████▉ | 360/403 [02:07<00:15,  2.81it/s]

  batch 360 loss: 7.167461566039926e-07
  batch 360 accuracy: 1.0


 92%|█████████▏| 370/403 [02:11<00:10,  3.00it/s]

  batch 370 loss: 7.182362761428684e-07
  batch 370 accuracy: 1.0


 94%|█████████▍| 380/403 [02:14<00:08,  2.75it/s]

  batch 380 loss: 6.888064554957071e-07
  batch 380 accuracy: 1.0


 97%|█████████▋| 390/403 [02:18<00:04,  2.83it/s]

  batch 390 loss: 6.776305777613422e-07
  batch 390 accuracy: 1.0


 99%|█████████▉| 400/403 [02:21<00:01,  2.75it/s]

  batch 400 loss: 6.51926060291254e-07
  batch 400 accuracy: 1.0


100%|██████████| 403/403 [02:22<00:00,  2.82it/s]


EPOCH 4:


  2%|▏         | 10/403 [00:03<02:27,  2.67it/s]

  batch 10 loss: 6.482007689713859e-07
  batch 10 accuracy: 1.0


  5%|▍         | 20/403 [00:07<02:19,  2.75it/s]

  batch 20 loss: 6.098302482772481e-07
  batch 20 accuracy: 1.0


  7%|▋         | 30/403 [00:10<02:14,  2.78it/s]

  batch 30 loss: 6.042423091658148e-07
  batch 30 accuracy: 1.0


 10%|▉         | 40/403 [00:14<02:05,  2.90it/s]

  batch 40 loss: 6.049873679360492e-07
  batch 40 accuracy: 1.0


 12%|█▏        | 50/403 [00:18<02:08,  2.75it/s]

  batch 50 loss: 6.161632537084413e-07
  batch 50 accuracy: 1.0


 15%|█▍        | 60/403 [00:21<02:01,  2.83it/s]

  batch 60 loss: 5.997719632057497e-07
  batch 60 accuracy: 1.0


 17%|█▋        | 70/403 [00:25<01:54,  2.90it/s]

  batch 70 loss: 5.640091397031116e-07
  batch 70 accuracy: 1.0


 20%|█▉        | 80/403 [00:28<01:56,  2.76it/s]

  batch 80 loss: 5.360694434354087e-07
  batch 80 accuracy: 1.0


 22%|██▏       | 90/403 [00:32<01:55,  2.71it/s]

  batch 90 loss: 5.669893927697029e-07
  batch 90 accuracy: 1.0


 25%|██▍       | 100/403 [00:35<01:45,  2.88it/s]

  batch 100 loss: 5.193056360043187e-07
  batch 100 accuracy: 1.0


 27%|██▋       | 110/403 [00:39<01:44,  2.80it/s]

  batch 110 loss: 5.573036450547344e-07
  batch 110 accuracy: 1.0


 30%|██▉       | 120/403 [00:42<01:38,  2.87it/s]

  batch 120 loss: 5.16697931796189e-07
  batch 120 accuracy: 1.0


 32%|███▏      | 130/403 [00:46<01:39,  2.74it/s]

  batch 130 loss: 4.865230673890375e-07
  batch 130 accuracy: 1.0


 35%|███▍      | 140/403 [00:49<01:30,  2.91it/s]

  batch 140 loss: 5.01424241683529e-07
  batch 140 accuracy: 1.0


 37%|███▋      | 150/403 [00:53<01:32,  2.73it/s]

  batch 150 loss: 4.85405481254291e-07
  batch 150 accuracy: 1.0


 40%|███▉      | 160/403 [00:57<01:27,  2.79it/s]

  batch 160 loss: 4.783274279774701e-07
  batch 160 accuracy: 1.0


 42%|████▏     | 170/403 [01:00<01:23,  2.78it/s]

  batch 170 loss: 4.83542839635576e-07
  batch 170 accuracy: 1.0


 45%|████▍     | 180/403 [01:04<01:19,  2.81it/s]

  batch 180 loss: 4.541130256053582e-07
  batch 180 accuracy: 1.0


 47%|████▋     | 190/403 [01:07<01:17,  2.74it/s]

  batch 190 loss: 4.4256461317899814e-07
  batch 190 accuracy: 1.0


 50%|████▉     | 200/403 [01:11<01:12,  2.80it/s]

  batch 200 loss: 4.2691838042513215e-07
  batch 200 accuracy: 1.0


 52%|█████▏    | 210/403 [01:14<01:08,  2.81it/s]

  batch 210 loss: 4.287810296377856e-07
  batch 210 accuracy: 1.0


 55%|█████▍    | 220/403 [01:18<01:05,  2.80it/s]

  batch 220 loss: 4.3846679631539114e-07
  batch 220 accuracy: 1.0


 57%|█████▋    | 230/403 [01:22<01:02,  2.79it/s]

  batch 230 loss: 4.0791939082134485e-07
  batch 230 accuracy: 1.0


 60%|█████▉    | 240/403 [01:25<00:56,  2.87it/s]

  batch 240 loss: 4.217029836440376e-07
  batch 240 accuracy: 1.0


 62%|██████▏   | 250/403 [01:29<00:53,  2.86it/s]

  batch 250 loss: 3.9562593205926813e-07
  batch 250 accuracy: 1.0


 65%|██████▍   | 260/403 [01:32<00:49,  2.89it/s]

  batch 260 loss: 3.915281158617959e-07
  batch 260 accuracy: 1.0


 67%|██████▋   | 270/403 [01:35<00:47,  2.82it/s]

  batch 270 loss: 4.094095245266853e-07
  batch 270 accuracy: 1.0


 69%|██████▉   | 280/403 [01:39<00:45,  2.68it/s]

  batch 280 loss: 3.5613783997416593e-07
  batch 280 accuracy: 1.0


 72%|███████▏  | 290/403 [01:43<00:40,  2.82it/s]

  batch 290 loss: 3.4719714059332174e-07
  batch 290 accuracy: 1.0


 74%|███████▍  | 300/403 [01:46<00:34,  2.98it/s]

  batch 300 loss: 3.408641454206032e-07
  batch 300 accuracy: 1.0


 77%|███████▋  | 310/403 [01:50<00:31,  2.92it/s]

  batch 310 loss: 3.390015039795263e-07
  batch 310 accuracy: 1.0


 79%|███████▉  | 320/403 [01:53<00:28,  2.92it/s]

  batch 320 loss: 3.360212736059334e-07
  batch 320 accuracy: 1.0


 82%|████████▏ | 330/403 [01:57<00:24,  2.92it/s]

  batch 330 loss: 3.192574493438331e-07
  batch 330 accuracy: 1.0


 84%|████████▍ | 340/403 [02:00<00:22,  2.86it/s]

  batch 340 loss: 3.144145742873074e-07
  batch 340 accuracy: 1.0


 87%|████████▋ | 350/403 [02:04<00:19,  2.77it/s]

  batch 350 loss: 3.054738676677978e-07
  batch 350 accuracy: 1.0


 89%|████████▉ | 360/403 [02:07<00:14,  2.93it/s]

  batch 360 loss: 2.9988592797904877e-07
  batch 360 accuracy: 1.0


 92%|█████████▏| 370/403 [02:11<00:11,  2.76it/s]

  batch 370 loss: 3.1515963208054543e-07
  batch 370 accuracy: 1.0


 94%|█████████▍| 380/403 [02:14<00:08,  2.79it/s]

  batch 380 loss: 3.0659145602299094e-07
  batch 380 accuracy: 1.0


 97%|█████████▋| 390/403 [02:18<00:04,  2.77it/s]

  batch 390 loss: 2.928078780773126e-07
  batch 390 accuracy: 1.0


 99%|█████████▉| 400/403 [02:21<00:01,  2.82it/s]

  batch 400 loss: 2.913177591601613e-07
  batch 400 accuracy: 1.0


100%|██████████| 403/403 [02:22<00:00,  2.82it/s]


EPOCH 5:


  2%|▏         | 10/403 [00:03<02:23,  2.75it/s]

  batch 10 loss: 2.946705243145616e-07
  batch 10 accuracy: 1.0


  5%|▍         | 20/403 [00:07<02:22,  2.68it/s]

  batch 20 loss: 3.296882739479023e-07
  batch 20 accuracy: 1.0


  7%|▋         | 30/403 [00:10<02:19,  2.68it/s]

  batch 30 loss: 3.0510134452214245e-07
  batch 30 accuracy: 1.0


 10%|▉         | 40/403 [00:14<02:20,  2.58it/s]

  batch 40 loss: 3.088266356643724e-07
  batch 40 accuracy: 1.0


 12%|█▏        | 50/403 [00:17<01:59,  2.95it/s]

  batch 50 loss: 2.9802328418429226e-07
  batch 50 accuracy: 1.0


 15%|█▍        | 60/403 [00:21<02:04,  2.76it/s]

  batch 60 loss: 3.024936461760061e-07
  batch 60 accuracy: 1.0


 17%|█▋        | 70/403 [00:25<01:59,  2.79it/s]

  batch 70 loss: 2.8312211783900716e-07
  batch 70 accuracy: 1.0


 20%|█▉        | 80/403 [00:28<01:50,  2.92it/s]

  batch 80 loss: 2.8423971014660046e-07
  batch 80 accuracy: 1.0


 22%|██▏       | 90/403 [00:32<01:56,  2.68it/s]

  batch 90 loss: 2.7976935810250133e-07
  batch 90 accuracy: 1.0


 25%|██▍       | 100/403 [00:35<01:54,  2.65it/s]

  batch 100 loss: 2.592802518464224e-07
  batch 100 accuracy: 1.0


 27%|██▋       | 110/403 [00:39<01:46,  2.76it/s]

  batch 110 loss: 2.633780776362341e-07
  batch 110 accuracy: 1.0


 30%|██▉       | 120/403 [00:42<01:39,  2.86it/s]

  batch 120 loss: 2.522022118922933e-07
  batch 120 accuracy: 1.0


 32%|███▏      | 130/403 [00:46<01:31,  2.98it/s]

  batch 130 loss: 2.458692110796317e-07
  batch 130 accuracy: 1.0


 35%|███▍      | 140/403 [00:49<01:34,  2.77it/s]

  batch 140 loss: 2.2388998839330525e-07
  batch 140 accuracy: 1.0


 37%|███▋      | 150/403 [00:53<01:26,  2.92it/s]

  batch 150 loss: 2.156943492481927e-07
  batch 150 accuracy: 1.0


 40%|███▉      | 160/403 [00:56<01:22,  2.94it/s]

  batch 160 loss: 2.160668771011998e-07
  batch 160 accuracy: 1.0


 42%|████▏     | 170/403 [01:00<01:23,  2.79it/s]

  batch 170 loss: 2.0451847413395578e-07
  batch 170 accuracy: 1.0


 45%|████▍     | 180/403 [01:04<01:20,  2.76it/s]

  batch 180 loss: 1.9967559485857515e-07
  batch 180 accuracy: 1.0


 47%|████▋     | 190/403 [01:07<01:14,  2.87it/s]

  batch 190 loss: 1.8775466243960054e-07
  batch 190 accuracy: 1.0


 50%|████▉     | 200/403 [01:11<01:10,  2.87it/s]

  batch 200 loss: 2.1085147472456994e-07
  batch 200 accuracy: 1.0


 52%|█████▏    | 210/403 [01:14<01:07,  2.85it/s]

  batch 210 loss: 1.7844143200830067e-07
  batch 210 accuracy: 1.0


 55%|█████▍    | 220/403 [01:18<01:05,  2.78it/s]

  batch 220 loss: 1.8291178320862792e-07
  batch 220 accuracy: 1.0


 57%|█████▋    | 230/403 [01:22<01:01,  2.81it/s]

  batch 230 loss: 1.8589201944420422e-07
  batch 230 accuracy: 1.0


 60%|█████▉    | 240/403 [01:25<00:58,  2.78it/s]

  batch 240 loss: 2.0228330284146145e-07
  batch 240 accuracy: 1.0


 62%|██████▏   | 250/403 [01:29<00:52,  2.92it/s]

  batch 250 loss: 1.9036236913462658e-07
  batch 250 accuracy: 1.0


 65%|██████▍   | 260/403 [01:32<00:50,  2.83it/s]

  batch 260 loss: 1.6652049950050734e-07
  batch 260 accuracy: 1.0


 67%|██████▋   | 270/403 [01:36<00:47,  2.82it/s]

  batch 270 loss: 1.739710847603713e-07
  batch 270 accuracy: 1.0


 69%|██████▉   | 280/403 [01:39<00:42,  2.89it/s]

  batch 280 loss: 1.8142167330649972e-07
  batch 280 accuracy: 1.0


 72%|███████▏  | 290/403 [01:43<00:40,  2.82it/s]

  batch 290 loss: 1.881271928239192e-07
  batch 290 accuracy: 1.0


 74%|███████▍  | 300/403 [01:46<00:35,  2.89it/s]

  batch 300 loss: 1.6912820490767376e-07
  batch 300 accuracy: 1.0


 77%|███████▋  | 310/403 [01:50<00:34,  2.69it/s]

  batch 310 loss: 1.6763809074228432e-07
  batch 310 accuracy: 1.0


 79%|███████▉  | 320/403 [01:53<00:28,  2.93it/s]

  batch 320 loss: 1.6428532758628856e-07
  batch 320 accuracy: 1.0


 82%|████████▏ | 330/403 [01:57<00:24,  2.96it/s]

  batch 330 loss: 1.4826657066536097e-07
  batch 330 accuracy: 1.0


 84%|████████▍ | 340/403 [02:00<00:22,  2.85it/s]

  batch 340 loss: 1.6056003497856136e-07
  batch 340 accuracy: 1.0


 87%|████████▋ | 350/403 [02:04<00:18,  2.86it/s]

  batch 350 loss: 1.6205014985449346e-07
  batch 350 accuracy: 1.0


 89%|████████▉ | 360/403 [02:07<00:15,  2.79it/s]

  batch 360 loss: 1.5720727328806068e-07
  batch 360 accuracy: 1.0


 92%|█████████▏| 370/403 [02:11<00:11,  2.89it/s]

  batch 370 loss: 1.739710860038223e-07
  batch 370 accuracy: 1.0


 94%|█████████▍| 380/403 [02:14<00:07,  2.91it/s]

  batch 380 loss: 1.6428532891855942e-07
  batch 380 accuracy: 1.0


 97%|█████████▋| 390/403 [02:18<00:04,  2.90it/s]

  batch 390 loss: 1.631677377211889e-07
  batch 390 accuracy: 1.0


 99%|█████████▉| 400/403 [02:21<00:01,  2.81it/s]

  batch 400 loss: 1.6503038529070841e-07
  batch 400 accuracy: 1.0


100%|██████████| 403/403 [02:22<00:00,  2.82it/s]


EPOCH 6:


  2%|▏         | 10/403 [00:03<02:17,  2.85it/s]

  batch 10 loss: 1.6912820699489545e-07
  batch 10 accuracy: 1.0


  5%|▍         | 20/403 [00:06<02:08,  2.97it/s]

  batch 20 loss: 1.6428532754187936e-07
  batch 20 accuracy: 1.0


  7%|▋         | 30/403 [00:10<02:06,  2.95it/s]

  batch 30 loss: 1.7434361461178357e-07
  batch 30 accuracy: 1.0


 10%|▉         | 40/403 [00:13<02:05,  2.89it/s]

  batch 40 loss: 1.7248096624290227e-07
  batch 40 accuracy: 1.0


 12%|█▏        | 50/403 [00:17<02:03,  2.87it/s]

  batch 50 loss: 1.6056003302456612e-07
  batch 50 accuracy: 1.0


 15%|█▍        | 60/403 [00:20<01:56,  2.94it/s]

  batch 60 loss: 1.6912820774984882e-07
  batch 60 accuracy: 1.0


 17%|█▋        | 70/403 [00:24<02:03,  2.69it/s]

  batch 70 loss: 1.5459956881348257e-07
  batch 70 accuracy: 1.0


 20%|█▉        | 80/403 [00:28<01:51,  2.89it/s]

  batch 80 loss: 1.6056003524501452e-07
  batch 80 accuracy: 1.0


 22%|██▏       | 90/403 [00:31<01:47,  2.91it/s]

  batch 90 loss: 1.609325630092047e-07
  batch 90 accuracy: 1.0


 25%|██▍       | 100/403 [00:35<01:48,  2.79it/s]

  batch 100 loss: 1.6652050123245785e-07
  batch 100 accuracy: 1.0


 27%|██▋       | 110/403 [00:38<01:46,  2.74it/s]

  batch 110 loss: 1.4714898266543875e-07
  batch 110 accuracy: 1.0


 30%|██▉       | 120/403 [00:42<01:38,  2.87it/s]

  batch 120 loss: 1.6093256545169844e-07
  batch 120 accuracy: 1.0


 32%|███▏      | 130/403 [00:45<01:40,  2.73it/s]

  batch 130 loss: 1.5087427727157007e-07
  batch 130 accuracy: 1.0


 35%|███▍      | 140/403 [00:49<01:34,  2.80it/s]

  batch 140 loss: 1.5832485737999335e-07
  batch 140 accuracy: 1.0


 37%|███▋      | 150/403 [00:53<01:35,  2.65it/s]

  batch 150 loss: 1.437962197314877e-07
  batch 150 accuracy: 1.0


 40%|███▉      | 160/403 [00:56<01:22,  2.94it/s]

  batch 160 loss: 1.6316773714387293e-07
  batch 160 accuracy: 1.0


 42%|████▏     | 170/403 [01:00<01:21,  2.84it/s]

  batch 170 loss: 1.4677645387984223e-07
  batch 170 accuracy: 1.0


 45%|████▍     | 180/403 [01:03<01:21,  2.73it/s]

  batch 180 loss: 1.6093256443029256e-07
  batch 180 accuracy: 1.0


 47%|████▋     | 190/403 [01:07<01:16,  2.78it/s]

  batch 190 loss: 1.6018750481628608e-07
  batch 190 accuracy: 1.0


 50%|████▉     | 200/403 [01:11<01:17,  2.62it/s]

  batch 200 loss: 1.5534462425304533e-07
  batch 200 accuracy: 1.0


 52%|█████▏    | 210/403 [01:14<01:09,  2.78it/s]

  batch 210 loss: 1.6130509303825356e-07
  batch 210 accuracy: 1.0


 55%|█████▍    | 220/403 [01:18<01:05,  2.78it/s]

  batch 220 loss: 1.5795232832794295e-07
  batch 220 accuracy: 1.0


 57%|█████▋    | 230/403 [01:21<00:59,  2.88it/s]

  batch 230 loss: 1.5459956517194693e-07
  batch 230 accuracy: 1.0


 60%|█████▉    | 240/403 [01:25<00:55,  2.92it/s]

  batch 240 loss: 1.5124680401435494e-07
  batch 240 accuracy: 1.0


 62%|██████▏   | 250/403 [01:28<00:53,  2.87it/s]

  batch 250 loss: 1.5087427491789484e-07
  batch 250 accuracy: 1.0


 65%|██████▍   | 260/403 [01:32<00:51,  2.76it/s]

  batch 260 loss: 1.4081598567195124e-07
  batch 260 accuracy: 1.0


 67%|██████▋   | 270/403 [01:36<00:48,  2.72it/s]

  batch 270 loss: 1.5869738629881732e-07
  batch 270 accuracy: 1.0


 69%|██████▉   | 280/403 [01:39<00:43,  2.83it/s]

  batch 280 loss: 1.545995661933532e-07
  batch 280 accuracy: 1.0


 72%|███████▏  | 290/403 [01:43<00:38,  2.93it/s]

  batch 290 loss: 1.5012921739110976e-07
  batch 290 accuracy: 1.0


 74%|███████▍  | 300/403 [01:46<00:37,  2.76it/s]

  batch 300 loss: 1.4528633847100088e-07
  batch 300 accuracy: 1.0


 77%|███████▋  | 310/403 [01:50<00:33,  2.78it/s]

  batch 310 loss: 1.3597310888347206e-07
  batch 310 accuracy: 1.0


 79%|███████▊  | 317/403 [01:52<00:30,  2.85it/s]