%pip install numpy==1.26.2

In [1]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

import logging

debug = logging.getLogger("Debug")
info  = print
plt.ion()   # interactive mode

<contextlib.ExitStack at 0x2b1372353a50>

In [2]:
#check GPU
device = None
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Running CUDA Mode:", device, torch.cuda.get_device_name(0))
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Running MPS Mode:", device)
else:
    device = torch.device("cpu")
    print("Running CPU Mode:", device)



Running CUDA Mode: cuda NVIDIA A100-PCIE-40GB


## Data and Classes
- Create Dataloader class

Note: Working on Part (a) as of now.  
Guiding light: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

In [3]:
START_TOKEN = "START"
END_TOKEN = "END"
UNK_TOKEN = "UNK"

# MAX_EXAMPLES = 100
class Vocabulary:
    def __init__(self, freq_dict, wd_to_id, id_to_wd):
        self.freq_dict = freq_dict
        self.wd_to_id = wd_to_id
        self.id_to_wd = id_to_wd
        self.N = len(freq_dict)
    
    def get_id(self, word):
        if word in self.wd_to_id:
            return self.wd_to_id[word]
        else:
            return self.wd_to_id[UNK_TOKEN]

class LatexFormulaDataset(Dataset):
    """Latex Formula Dataset: Image and Text"""
    
    def __init__(self, csv_file, root_dir, max_examples=None, transform = None):
        """
        Arguments:
            csv_file (string): Path to the csv file with image name and text
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        #@TODO: May want to preload images
        info("Loading Dataset...")
        self.df = pd.read_csv(csv_file)
        info("Loaded Dataset", self.df.info)
        
        #Slice the dataset if max_examples is not None
        if max_examples is not None:
            self.df = self.df.iloc[:max_examples, :]

        self.root_dir = root_dir
        self.transform = transform
        self.vocab= self.construct_vocab()  

        self.maxlen = 0
        for formula in self.df['formula']:
            if len(formula) > self.maxlen:
                self.maxlen = len(formula)

        self.df['formula'] = self.df['formula'].apply(lambda x: [START_TOKEN] + x + [END_TOKEN] + [UNK_TOKEN]*(self.maxlen - len(x)))
        self.maxlen += 2 #for start and end tokens

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """
        Returns sample of type image, textformula
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.df.iloc[idx, 0])
        image = io.imread(img_name)
        formula = self.df.iloc[idx, 1]
        formula = np.array([formula], dtype=str).reshape(-1, 1)
        formula = [[self.vocab.get_id(wd[0]) for wd in formula]] 
        sample = {'image': image, 'formula': torch.tensor(formula, dtype=torch.int64)}

        if self.transform:
            sample['image'] = self.transform(sample['image'])
            
        return sample 
    
    def construct_vocab(self):
        """
        Constructs vocabulary from the dataset formulas
        """
        #Split on spaces to tokenize
        self.df['formula'] = self.df['formula'].apply(lambda x: x.split())

        freq_dict = {}
        for formula in self.df['formula']:
            for wd in formula:
                if wd not in freq_dict:
                    freq_dict[wd] = 1
                else:
                    freq_dict[wd] += 1
        freq_dict[START_TOKEN] = 1
        freq_dict[END_TOKEN] = 1
        freq_dict[UNK_TOKEN] = 1
        N = len(freq_dict)
        wd_to_id = {}
        for i, wd in enumerate(freq_dict):
            wd_to_id[wd] = i
        id_to_wd = {v: k for k, v in wd_to_id.items()}
    
        #pad the formulas with 
        return Vocabulary(freq_dict, wd_to_id, id_to_wd)      

def get_dataloader(csv_path, image_root, batch_size, transform = None, max_examples = None):
    """
    Returns dataloader,dataset for the dataset
    """
    dataset = LatexFormulaDataset(csv_path, image_root, max_examples=max_examples,transform=transform) #checked
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader, dataset
     

### Encoder Network
- A CNN to encode image to more meaningful vector

In [4]:
class EncoderCNN(nn.Module):
    def __init__(self):
        super().__init__()
    
        #@TODO:reduce number of layers: eliminate pools and acts
        self.conv1 = nn.Conv2d(3, 32, (5, 5))
        self.act1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d((2, 2))
        
        self.conv2 = nn.Conv2d(32, 64, (5, 5))
        self.act2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d((2, 2))
        
        self.conv3 = nn.Conv2d(64, 128, (5, 5))
        self.act3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d((2, 2))
        
        self.conv4 = nn.Conv2d(128, 256, (5, 5))
        self.act4 = nn.ReLU()
        self.pool4 = nn.MaxPool2d((2, 2))
        
        self.conv5 = nn.Conv2d(256, 512, (5, 5))
        self.act5 = nn.ReLU()
        self.pool5 = nn.MaxPool2d((2, 2))
        
        self.avg_pool = nn.AvgPool2d((3, 3))
    
    def forward(self, x):
        x = self.act1(self.conv1(x))
        x = self.pool1(x)
        
        x = self.act2(self.conv2(x))
        x = self.pool2(x)
        
        x = self.act3(self.conv3(x))
        x = self.pool3(x)
        
        x = self.act4(self.conv4(x))
        x = self.pool4(x)
        
        x = self.act5(self.conv5(x))
        x = self.pool5(x)
        
        x = self.avg_pool(x)
        x = x.view(-1,512) 
        # info(f"Encoder Output Shape: {x.shape}")
        return x

### Vocabulary
- https://github.com/harvardnlp/im2markup/blob/master

### Decoder Network

In [5]:
class Decoder(nn.Module):
    """
    Inputs:
    (here M is whatever the batch size is passed)

    context_size : size of the context vector [shape: (1,M,context_size)]
    n_layers: number of layers [for our purposes, defaults to 1]
    hidden_size : size of the hidden state vectors [shape: (n_layers,M,hidden_size)]
    embed_size : size of the embedding vectors [shape: (1,M,embed_size)]
    vocab_size : size of the vocabulary
    max_length : maximum length of the formula
    """
    def __init__(self, context_size, vocab, n_layers = 1, hidden_size = 512, embed_size = 512,  max_length = 100):
        super().__init__()
        self.context_size = context_size
        self.vocab = vocab
        self.vocab_size = vocab.N
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.max_length = max_length


        self.input_size = context_size + embed_size

        self.embed = nn.Embedding(self.vocab_size, embed_size)
        self.lstm = nn.LSTMCell(self.input_size, self.hidden_size)
        self.linear = nn.Linear(hidden_size, self.vocab_size)
        self.softmax = nn.Softmax(dim = 1)
    
    def forward(self, context, target_tensor = None):
        """
        M: batch_size
        context is the context vector from the encoder [shape: (M,context_size)]
        target_tensor is the formula in tensor form [shape: (M,max_length)] (in the second dimension, it is sequence of indices of formula tokens)
            if target_tensor is not None, then we are in Teacher Forcing mode
            else normal jo bhi (last prediction ka embed is concatenated)
        """
        # info("Decoder Forward")
        # info(f"Context shape: {context.shape}")
        context.to(device)
        target_tensor = target_tensor.squeeze()
        batch_size = context.shape[0]

        #initialize hidden state and cell state
            #@TODO: Some caveat in the size of the cell state. Should it be same as hidden_size? (check nn.LSTM documentation)
        hidden = torch.zeros((batch_size, self.hidden_size)).to(context.device)
        cell = torch.zeros((batch_size, self.hidden_size)).to(context.device)

        #initialize the input with embedding of the start token
        init_embed = self.embed(torch.tensor([self.vocab.wd_to_id[START_TOKEN]]).to(device)).reshape((1, self.embed_size))
        init_embed = torch.repeat_interleave(init_embed, batch_size, dim = 0).to(context.device)

        # info(f"Initial Embedding Shape: {init_embed.shape}")

        input = torch.cat([context, init_embed], dim = 1).to(context.device)

        #initialize the output_history and init_output
        outputs = []
        output = torch.zeros((batch_size, self.vocab_size)).to(context.device)
        
        
        for i in range(self.max_length):
            hidden, cell = self.lstm(input, (hidden, cell))
            output = self.linear(hidden)
            # output = self.softmax(output)
            outputs.append(output)
            
            #teacher forcing: 50% times
            r = torch.rand(1)
            if r>0.5:
                embedding = self.embed(target_tensor[:, i]).reshape((batch_size, self.embed_size)).to(context.device)
                input = torch.cat([context, embedding], dim = 1).to(context.device)              
            else:
                #add the embedding of the last prediction
                input = torch.cat([context, self.embed(torch.argmax(output, dim = 1))], dim = 1).to(context.device)
        # info(f"Outputs: {outputs}")
        return torch.stack(outputs).to(context.device), hidden, cell

### Utility Functions

In [6]:
import time
import math
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from tqdm import tqdm

plt.switch_backend('agg')
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)
def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)
def saveModel(save_path, model_state, optimiser_state, loss):
    torch.save({
            'model_state_dict': model_state,
            'optimizer_state_dict':optimiser_state,
            'loss': loss,  
    }, save_path)
    

### Training Code.
- Dataloader automatically loads in batches. The data need not be modified by us.

In [7]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion):
    total_loss = 0
    idx = 0
    for data in dataloader:
        idx+=1
        
        info(f"----Batch {idx}----")
        input_tensor, target_tensor = data['image'].to(device), data['formula'].to(device)

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        
        encoder_output = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_output, target_tensor)
        
        # print(encoder_output.device, 'My device')
        
        # print(f"Decoder OutDim : {decoder_outputs.shape}, Target Tensor Dim: {target_tensor.shape}")
        # print(f"Target tensor: {target_tensor[0][0]}")
        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()
        
        print(f'Loss for batch {idx} = {loss.item()}')

    return total_loss / len(dataloader)

def train(train_dataloader, encoder, decoder, n_epochs, save_interval = 2, learning_rate=0.001, print_every=1, plot_every=5):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss().to(device) #as stated in assignment
    
    # Print model's device
    # print("Encoder's device:", next(encoder.parameters()).device)

    pb = tqdm(range(1, n_epochs + 1))
    for epoch in pb:
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            
        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
        
        if epoch % save_interval == 0:
            saveModel(f'checkpoints/encoder_epoch_{epoch}.pt', encoder.state_dict(), encoder_optimizer.state_dict(), loss)
            saveModel(f'checkpoints/decoder_epoch_{epoch}.pt', decoder.state_dict(), decoder_optimizer.state_dict(), loss)
            

        pb.set_description('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs), epoch, epoch / n_epochs * 100, print_loss_avg))
        
    showPlot(plot_losses)

## Training

In [8]:
batch_size = 32
vocab_size = 1000
CONTEXT_SIZE = 512
HIDDEN_SIZE = 512
EMBED_SIZE = 512
MAX_EXAMPLES = 1000
# image processing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda x: x/255.0), #min-max normalisation
])

### Load Dataset and Dataloader

In [9]:
#part a
#train_csv_path = "/kaggle/input/converting-handwritten-equations-to-latex-code/col_774_A4_2023/SyntheticData/train.csv"
#image_root_path = "/kaggle/input/converting-handwritten-equations-to-latex-code/col_774_A4_2023/SyntheticData/images"
train_csv_path = "data/SyntheticData/train.csv"
image_root_path = "data/SyntheticData/images"
train_dataloader, train_dataset = get_dataloader(train_csv_path, image_root_path, batch_size, transform, max_examples=None)

Loading Dataset...
Loaded Dataset <bound method DataFrame.info of                 image                                            formula
0      74d337e8a0.png  $ \gamma _ { \Omega R , 5 } ^ { T } = - \gamma...
1      2d0f18f71d.png  $ l ^ { ( -- ) \underline { { m } } } u _ { \u...
2      6d9b9de88d.png  $ \left[ H , \gamma _ { i } ^ { \left( 2 \righ...
3      38c6d510bb.png  $ < a _ { i } > \; \propto \; \int _ { \omega ...
4      24537a86e3.png  $ \Psi ( \mu _ { 1 } , \ldots , \mu _ { K } ) ...
...               ...                                                ...
74995  1fa37e67d2.png  $ T _ { \theta } ^ { \theta } = - \frac { 1 } ...
74996  75518a26df.png  $ \alpha _ { + } = - 1 / \alpha _ { - } = \sqr...
74997  29f28cbc3a.png  $ d s ^ { 2 } = Z ^ { - 1 / 2 } \eta _ { \mu \...
74998  33ac7b385d.png  $ \tilde { H } _ { 0 } = \frac { 1 } { 2 } ( \...
74999  52672fbf76.png  $ \psi _ { \alpha \beta } = - g _ { \alpha \ga...

[75000 rows x 2 columns]>


### Create Model

In [10]:
#create a network instance
encoder = EncoderCNN().to(device)
decoder = Decoder(CONTEXT_SIZE, train_dataset.vocab, n_layers=1, hidden_size= HIDDEN_SIZE, embed_size=EMBED_SIZE,max_length=train_dataset.maxlen).to(device)
print(torch.cuda.get_device_name(0))

NVIDIA A100-PCIE-40GB


### Train

In [None]:
train(train_dataloader, encoder, decoder, 10, save_interval = 2)

  0%|                                                                                             | 0/10 [00:00<?, ?it/s]

----Batch 1----
Loss for batch 1 = 6.336750030517578
----Batch 2----
Loss for batch 2 = 4.088656902313232
----Batch 3----
Loss for batch 3 = 1.8909673690795898
----Batch 4----
Loss for batch 4 = 0.934638261795044
----Batch 5----
Loss for batch 5 = 0.759469211101532
----Batch 6----
Loss for batch 6 = 0.8507074117660522
----Batch 7----
Loss for batch 7 = 0.7870851159095764
----Batch 8----
Loss for batch 8 = 0.7733004093170166
----Batch 9----
Loss for batch 9 = 0.7697874903678894
----Batch 10----
Loss for batch 10 = 0.7144590020179749
----Batch 11----
Loss for batch 11 = 0.7257136106491089
----Batch 12----
Loss for batch 12 = 0.7066599726676941
----Batch 13----
Loss for batch 13 = 0.7501876950263977
----Batch 14----
Loss for batch 14 = 0.7361814975738525
----Batch 15----
Loss for batch 15 = 0.6961299777030945
----Batch 16----
Loss for batch 16 = 0.7556687593460083
----Batch 17----
Loss for batch 17 = 0.6457914113998413
----Batch 18----
Loss for batch 18 = 0.7300599813461304
----Batch 19--

----Batch 147----
Loss for batch 147 = 0.6744541525840759
----Batch 148----
Loss for batch 148 = 0.6650529503822327
----Batch 149----
Loss for batch 149 = 0.6994333863258362
----Batch 150----
Loss for batch 150 = 0.6529591679573059
----Batch 151----
Loss for batch 151 = 0.6387475728988647
----Batch 152----
Loss for batch 152 = 0.6664289832115173
----Batch 153----
Loss for batch 153 = 0.7768464684486389
----Batch 154----
Loss for batch 154 = 0.7318217158317566
----Batch 155----
Loss for batch 155 = 0.7109777331352234
----Batch 156----
Loss for batch 156 = 0.7024853229522705
----Batch 157----
Loss for batch 157 = 0.6559823751449585
----Batch 158----
Loss for batch 158 = 0.6412087678909302
----Batch 159----
Loss for batch 159 = 0.8097293376922607
----Batch 160----
Loss for batch 160 = 0.7927542328834534
----Batch 161----
Loss for batch 161 = 0.7064275145530701
----Batch 162----
Loss for batch 162 = 0.7029843926429749
----Batch 163----
Loss for batch 163 = 0.6747865676879883
----Batch 164-

----Batch 289----
Loss for batch 289 = 0.7041324973106384
----Batch 290----
Loss for batch 290 = 0.7091871500015259
----Batch 291----
Loss for batch 291 = 0.775210976600647
----Batch 292----
Loss for batch 292 = 0.7401321530342102
----Batch 293----
Loss for batch 293 = 0.6901451945304871
----Batch 294----
Loss for batch 294 = 0.6835574507713318
----Batch 295----
Loss for batch 295 = 0.6795048117637634
----Batch 296----
Loss for batch 296 = 0.7966568470001221
----Batch 297----
Loss for batch 297 = 0.7475168108940125
----Batch 298----
Loss for batch 298 = 0.686545729637146
----Batch 299----
Loss for batch 299 = 0.6913046836853027
----Batch 300----
Loss for batch 300 = 0.6533486247062683
----Batch 301----
Loss for batch 301 = 0.6973468661308289
----Batch 302----
Loss for batch 302 = 0.7530636191368103
----Batch 303----
Loss for batch 303 = 0.6851115822792053
----Batch 304----
Loss for batch 304 = 0.7614176273345947
----Batch 305----
Loss for batch 305 = 0.7336505055427551
----Batch 306---

----Batch 431----
Loss for batch 431 = 0.6351765394210815
----Batch 432----
Loss for batch 432 = 0.7366238832473755
----Batch 433----
Loss for batch 433 = 0.7277689576148987
----Batch 434----
Loss for batch 434 = 0.8096972703933716
----Batch 435----
Loss for batch 435 = 0.8596280813217163
----Batch 436----
Loss for batch 436 = 0.6764630675315857
----Batch 437----
Loss for batch 437 = 0.788262665271759
----Batch 438----
Loss for batch 438 = 0.7149176597595215
----Batch 439----
Loss for batch 439 = 0.7677329778671265
----Batch 440----
Loss for batch 440 = 0.7264788150787354
----Batch 441----
Loss for batch 441 = 0.6461790800094604
----Batch 442----
Loss for batch 442 = 0.7471745014190674
----Batch 443----
Loss for batch 443 = 0.7360185980796814
----Batch 444----
Loss for batch 444 = 0.5982176661491394
----Batch 445----
Loss for batch 445 = 0.7458980679512024
----Batch 446----
Loss for batch 446 = 0.7365542054176331
----Batch 447----
Loss for batch 447 = 0.618071973323822
----Batch 448---

----Batch 573----
Loss for batch 573 = 0.7218344807624817
----Batch 574----
Loss for batch 574 = 0.694145917892456
----Batch 575----
Loss for batch 575 = 0.7493459582328796
----Batch 576----
Loss for batch 576 = 0.7099735736846924
----Batch 577----
Loss for batch 577 = 0.6904744505882263
----Batch 578----
Loss for batch 578 = 0.6587796807289124
----Batch 579----
Loss for batch 579 = 0.6618769764900208
----Batch 580----
Loss for batch 580 = 0.6781450510025024
----Batch 581----
Loss for batch 581 = 0.6816129684448242
----Batch 582----
Loss for batch 582 = 0.7098134160041809
----Batch 583----
Loss for batch 583 = 0.6826135516166687
----Batch 584----
Loss for batch 584 = 0.7491320371627808
----Batch 585----
Loss for batch 585 = 0.6897211074829102
----Batch 586----
Loss for batch 586 = 0.7988650798797607
----Batch 587----
Loss for batch 587 = 0.704430878162384
----Batch 588----
Loss for batch 588 = 0.7161509990692139
----Batch 589----
Loss for batch 589 = 0.7584951519966125
----Batch 590---

----Batch 715----
Loss for batch 715 = 0.7544010281562805
----Batch 716----
Loss for batch 716 = 0.6641505360603333
----Batch 717----
Loss for batch 717 = 0.6534276604652405
----Batch 718----
Loss for batch 718 = 0.7085065245628357
----Batch 719----
Loss for batch 719 = 0.7080752849578857
----Batch 720----
Loss for batch 720 = 0.5810132026672363
----Batch 721----
Loss for batch 721 = 0.7202900648117065
----Batch 722----
Loss for batch 722 = 0.7107740640640259
----Batch 723----
Loss for batch 723 = 0.6328524947166443
----Batch 724----
Loss for batch 724 = 0.6054275035858154
----Batch 725----
Loss for batch 725 = 0.752530038356781
----Batch 726----
Loss for batch 726 = 0.6331641674041748
----Batch 727----
Loss for batch 727 = 0.6830844879150391
----Batch 728----
Loss for batch 728 = 0.7748565673828125
----Batch 729----
Loss for batch 729 = 0.6659367680549622
----Batch 730----
Loss for batch 730 = 0.6996291875839233
----Batch 731----
Loss for batch 731 = 0.6372979283332825
----Batch 732--

----Batch 857----
Loss for batch 857 = 0.7205944657325745
----Batch 858----
Loss for batch 858 = 0.6488192081451416
----Batch 859----
Loss for batch 859 = 0.6760633587837219
----Batch 860----
Loss for batch 860 = 0.6927458643913269
----Batch 861----
Loss for batch 861 = 0.7610712647438049
----Batch 862----
Loss for batch 862 = 0.6194500923156738
----Batch 863----
Loss for batch 863 = 0.7328774333000183
----Batch 864----
Loss for batch 864 = 0.7241413593292236
----Batch 865----
Loss for batch 865 = 0.6450161337852478
----Batch 866----
Loss for batch 866 = 0.702581524848938
----Batch 867----
Loss for batch 867 = 0.5972874760627747
----Batch 868----
Loss for batch 868 = 0.6744013428688049
----Batch 869----
Loss for batch 869 = 0.7423197627067566
----Batch 870----
Loss for batch 870 = 0.6614496111869812
----Batch 871----
Loss for batch 871 = 0.7571935057640076
----Batch 872----
Loss for batch 872 = 0.7411447763442993
----Batch 873----
Loss for batch 873 = 0.7161589860916138
----Batch 874--

----Batch 999----
Loss for batch 999 = 0.6022792458534241
----Batch 1000----
Loss for batch 1000 = 0.6435367465019226
----Batch 1001----
Loss for batch 1001 = 0.7226988673210144
----Batch 1002----
Loss for batch 1002 = 0.6489488482475281
----Batch 1003----
Loss for batch 1003 = 0.7144573330879211
----Batch 1004----
Loss for batch 1004 = 0.627009391784668
----Batch 1005----
Loss for batch 1005 = 0.7423622012138367
----Batch 1006----
Loss for batch 1006 = 0.6757352948188782
----Batch 1007----
Loss for batch 1007 = 0.67897629737854
----Batch 1008----
Loss for batch 1008 = 0.5882512331008911
----Batch 1009----
Loss for batch 1009 = 0.6921077370643616
----Batch 1010----
Loss for batch 1010 = 0.7013546824455261
----Batch 1011----
Loss for batch 1011 = 0.7164918780326843
----Batch 1012----
Loss for batch 1012 = 0.649482011795044
----Batch 1013----
Loss for batch 1013 = 0.6759122014045715
----Batch 1014----
Loss for batch 1014 = 0.730982780456543
----Batch 1015----
Loss for batch 1015 = 0.7300

----Batch 1136----
Loss for batch 1136 = 0.6857531666755676
----Batch 1137----
Loss for batch 1137 = 0.66587233543396
----Batch 1138----
Loss for batch 1138 = 0.7145202159881592
----Batch 1139----
Loss for batch 1139 = 0.7068153023719788
----Batch 1140----
Loss for batch 1140 = 0.6865448355674744
----Batch 1141----
Loss for batch 1141 = 0.6164995431900024
----Batch 1142----
Loss for batch 1142 = 0.6816274523735046
----Batch 1143----
Loss for batch 1143 = 0.735662579536438
----Batch 1144----
Loss for batch 1144 = 0.688392162322998
----Batch 1145----
Loss for batch 1145 = 0.6516090631484985
----Batch 1146----
Loss for batch 1146 = 0.7455224394798279
----Batch 1147----
Loss for batch 1147 = 0.6867757439613342
----Batch 1148----
Loss for batch 1148 = 0.7364245653152466
----Batch 1149----
Loss for batch 1149 = 0.664126455783844
----Batch 1150----
Loss for batch 1150 = 0.6322470307350159
----Batch 1151----
Loss for batch 1151 = 0.7567314505577087
----Batch 1152----
Loss for batch 1152 = 0.60

----Batch 1273----
Loss for batch 1273 = 0.617983877658844
----Batch 1274----
Loss for batch 1274 = 0.7186917662620544
----Batch 1275----
Loss for batch 1275 = 0.6318919062614441
----Batch 1276----
Loss for batch 1276 = 0.7129671573638916
----Batch 1277----
Loss for batch 1277 = 0.7189098000526428
----Batch 1278----
Loss for batch 1278 = 0.7326858043670654
----Batch 1279----
Loss for batch 1279 = 0.7792018055915833
----Batch 1280----
Loss for batch 1280 = 0.6662009358406067
----Batch 1281----
Loss for batch 1281 = 0.6544719338417053
----Batch 1282----
Loss for batch 1282 = 0.7607899308204651
----Batch 1283----
Loss for batch 1283 = 0.6901742219924927
----Batch 1284----
Loss for batch 1284 = 0.7789520621299744
----Batch 1285----
Loss for batch 1285 = 0.6239725947380066
----Batch 1286----
Loss for batch 1286 = 0.7059160470962524
----Batch 1287----
Loss for batch 1287 = 0.710260272026062
----Batch 1288----
Loss for batch 1288 = 0.7297325730323792
----Batch 1289----
Loss for batch 1289 = 0

----Batch 1410----
Loss for batch 1410 = 0.7140955924987793
----Batch 1411----
Loss for batch 1411 = 0.6537958979606628
----Batch 1412----
Loss for batch 1412 = 0.7151766419410706
----Batch 1413----
Loss for batch 1413 = 0.7229346632957458
----Batch 1414----
Loss for batch 1414 = 0.6964122653007507
----Batch 1415----
Loss for batch 1415 = 0.7406502962112427
----Batch 1416----
Loss for batch 1416 = 0.7143641114234924
----Batch 1417----
Loss for batch 1417 = 0.7778509259223938
----Batch 1418----
Loss for batch 1418 = 0.7073612213134766
----Batch 1419----
Loss for batch 1419 = 0.6869283318519592
----Batch 1420----
Loss for batch 1420 = 0.6848787069320679
----Batch 1421----
Loss for batch 1421 = 0.6636192202568054
----Batch 1422----
Loss for batch 1422 = 0.757472038269043
----Batch 1423----
Loss for batch 1423 = 0.6562491059303284
----Batch 1424----
Loss for batch 1424 = 0.6833276152610779
----Batch 1425----
Loss for batch 1425 = 0.7115554213523865
----Batch 1426----
Loss for batch 1426 = 

----Batch 1547----
Loss for batch 1547 = 0.8386951684951782
----Batch 1548----
Loss for batch 1548 = 0.6784811019897461
----Batch 1549----
Loss for batch 1549 = 0.6983582377433777
----Batch 1550----
Loss for batch 1550 = 0.6276189088821411
----Batch 1551----
Loss for batch 1551 = 0.7279150485992432
----Batch 1552----
Loss for batch 1552 = 0.7215750813484192
----Batch 1553----
Loss for batch 1553 = 0.7135194540023804
----Batch 1554----
Loss for batch 1554 = 0.636299729347229
----Batch 1555----
Loss for batch 1555 = 0.6211175918579102
----Batch 1556----
Loss for batch 1556 = 0.754625678062439
----Batch 1557----
Loss for batch 1557 = 0.7213339805603027
----Batch 1558----
Loss for batch 1558 = 0.7150394916534424
----Batch 1559----
Loss for batch 1559 = 0.6567347645759583
----Batch 1560----
Loss for batch 1560 = 0.699180006980896
----Batch 1561----
Loss for batch 1561 = 0.7310092449188232
----Batch 1562----
Loss for batch 1562 = 0.6585364937782288
----Batch 1563----
Loss for batch 1563 = 0.

----Batch 1684----
Loss for batch 1684 = 0.7337121367454529
----Batch 1685----
Loss for batch 1685 = 0.7122071981430054
----Batch 1686----
Loss for batch 1686 = 0.736432671546936
----Batch 1687----
Loss for batch 1687 = 0.6665481328964233
----Batch 1688----
Loss for batch 1688 = 0.6961634755134583
----Batch 1689----
Loss for batch 1689 = 0.6716952323913574
----Batch 1690----
Loss for batch 1690 = 0.6552382707595825
----Batch 1691----
Loss for batch 1691 = 0.7194363474845886
----Batch 1692----
Loss for batch 1692 = 0.7569671273231506
----Batch 1693----
Loss for batch 1693 = 0.6560515761375427
----Batch 1694----
Loss for batch 1694 = 0.7083282470703125
----Batch 1695----
Loss for batch 1695 = 0.7202997803688049
----Batch 1696----
Loss for batch 1696 = 0.6728671789169312
----Batch 1697----
Loss for batch 1697 = 0.6917527914047241
----Batch 1698----
Loss for batch 1698 = 0.7351111173629761
----Batch 1699----
Loss for batch 1699 = 0.6928327679634094
----Batch 1700----
Loss for batch 1700 = 

----Batch 1821----
Loss for batch 1821 = 0.7837533354759216
----Batch 1822----
Loss for batch 1822 = 0.6718737483024597
----Batch 1823----
Loss for batch 1823 = 0.7298814654350281
----Batch 1824----
Loss for batch 1824 = 0.6961553692817688
----Batch 1825----
Loss for batch 1825 = 0.723446786403656
----Batch 1826----
Loss for batch 1826 = 0.7648847699165344
----Batch 1827----
Loss for batch 1827 = 0.7078773379325867
----Batch 1828----
Loss for batch 1828 = 0.7109223008155823
----Batch 1829----
Loss for batch 1829 = 0.7388740181922913
----Batch 1830----
Loss for batch 1830 = 0.7188839316368103
----Batch 1831----
Loss for batch 1831 = 0.7271689772605896
----Batch 1832----
Loss for batch 1832 = 0.7944285869598389
----Batch 1833----
Loss for batch 1833 = 0.6756759285926819
----Batch 1834----
Loss for batch 1834 = 0.6513779163360596
----Batch 1835----
Loss for batch 1835 = 0.6926438212394714
----Batch 1836----
Loss for batch 1836 = 0.601710319519043
----Batch 1837----
Loss for batch 1837 = 0

----Batch 1958----
Loss for batch 1958 = 0.8067656755447388
----Batch 1959----
Loss for batch 1959 = 0.6641156077384949
----Batch 1960----
Loss for batch 1960 = 0.718574583530426
----Batch 1961----
Loss for batch 1961 = 0.683213472366333
----Batch 1962----
Loss for batch 1962 = 0.6342877745628357
----Batch 1963----
Loss for batch 1963 = 0.729285717010498
----Batch 1964----
Loss for batch 1964 = 0.772255539894104
----Batch 1965----
Loss for batch 1965 = 0.6652154922485352
----Batch 1966----
Loss for batch 1966 = 0.677757978439331
----Batch 1967----
Loss for batch 1967 = 0.7300474047660828
----Batch 1968----
Loss for batch 1968 = 0.6371285915374756
----Batch 1969----
Loss for batch 1969 = 0.7281906604766846
----Batch 1970----
Loss for batch 1970 = 0.6692319512367249
----Batch 1971----
Loss for batch 1971 = 0.6659696102142334
----Batch 1972----
Loss for batch 1972 = 0.7007879614830017
----Batch 1973----
Loss for batch 1973 = 0.7260075807571411
----Batch 1974----
Loss for batch 1974 = 0.71

----Batch 2095----
Loss for batch 2095 = 0.7146406769752502
----Batch 2096----
Loss for batch 2096 = 0.7647408843040466
----Batch 2097----
Loss for batch 2097 = 0.6898825168609619
----Batch 2098----
Loss for batch 2098 = 0.7206010222434998
----Batch 2099----
Loss for batch 2099 = 0.734613299369812
----Batch 2100----
Loss for batch 2100 = 0.7157195210456848
----Batch 2101----
Loss for batch 2101 = 0.7192611694335938
----Batch 2102----
Loss for batch 2102 = 0.6728798747062683
----Batch 2103----
Loss for batch 2103 = 0.7441477179527283
----Batch 2104----
Loss for batch 2104 = 0.6684219837188721
----Batch 2105----
Loss for batch 2105 = 0.6556771397590637
----Batch 2106----
Loss for batch 2106 = 0.6976993083953857
----Batch 2107----
Loss for batch 2107 = 0.8265498280525208
----Batch 2108----
Loss for batch 2108 = 0.7433167695999146
----Batch 2109----
Loss for batch 2109 = 0.6914535164833069
----Batch 2110----
Loss for batch 2110 = 0.5937029719352722
----Batch 2111----
Loss for batch 2111 = 

----Batch 2232----
Loss for batch 2232 = 0.7388573288917542
----Batch 2233----
Loss for batch 2233 = 0.7137944102287292
----Batch 2234----
Loss for batch 2234 = 0.7908998727798462
----Batch 2235----
Loss for batch 2235 = 0.7194564938545227
----Batch 2236----
Loss for batch 2236 = 0.7061835527420044
----Batch 2237----
Loss for batch 2237 = 0.6687461137771606
----Batch 2238----
Loss for batch 2238 = 0.7138782143592834
----Batch 2239----
Loss for batch 2239 = 0.7120891809463501
----Batch 2240----
Loss for batch 2240 = 0.7366477847099304
----Batch 2241----
Loss for batch 2241 = 0.7360848188400269
----Batch 2242----
Loss for batch 2242 = 0.6258131265640259
----Batch 2243----
Loss for batch 2243 = 0.6629133820533752
----Batch 2244----
Loss for batch 2244 = 0.7806708216667175
----Batch 2245----
Loss for batch 2245 = 0.616030752658844
----Batch 2246----
Loss for batch 2246 = 0.5925102829933167
----Batch 2247----
Loss for batch 2247 = 0.7496734857559204
----Batch 2248----
Loss for batch 2248 = 

32m 48s (- 295m 17s) (1 10%) 0.7076:  10%|████▍                                       | 1/10 [32:48<4:55:17, 1968.61s/it]

Loss for batch 2344 = 0.7824202179908752
----Batch 1----
Loss for batch 1 = 0.6631912589073181
----Batch 2----
Loss for batch 2 = 0.6759697198867798
----Batch 3----
Loss for batch 3 = 0.7161123156547546
----Batch 4----
Loss for batch 4 = 0.7074294090270996
----Batch 5----
Loss for batch 5 = 0.7235667109489441
----Batch 6----
Loss for batch 6 = 0.7286194562911987
----Batch 7----
Loss for batch 7 = 0.6664547920227051
----Batch 8----
Loss for batch 8 = 0.6477235555648804
----Batch 9----
Loss for batch 9 = 0.7044108510017395
----Batch 10----
Loss for batch 10 = 0.7561968564987183
----Batch 11----
Loss for batch 11 = 0.7242017388343811
----Batch 12----
Loss for batch 12 = 0.6615690588951111
----Batch 13----
Loss for batch 13 = 0.7198619246482849
----Batch 14----
Loss for batch 14 = 0.7349472045898438
----Batch 15----
Loss for batch 15 = 0.7040112018585205
----Batch 16----
Loss for batch 16 = 0.7117294669151306
----Batch 17----
Loss for batch 17 = 0.6493587493896484
----Batch 18----
Loss for

Loss for batch 146 = 0.6538780927658081
----Batch 147----
Loss for batch 147 = 0.76242995262146
----Batch 148----
Loss for batch 148 = 0.6612457036972046
----Batch 149----
Loss for batch 149 = 0.662057638168335
----Batch 150----
Loss for batch 150 = 0.7472861409187317
----Batch 151----
Loss for batch 151 = 0.7550366520881653
----Batch 152----
Loss for batch 152 = 0.6840179562568665
----Batch 153----
Loss for batch 153 = 0.6201678514480591
----Batch 154----
Loss for batch 154 = 0.6488479375839233
----Batch 155----
Loss for batch 155 = 0.7210575342178345
----Batch 156----
Loss for batch 156 = 0.6416693925857544
----Batch 157----
Loss for batch 157 = 0.6642118096351624
----Batch 158----
Loss for batch 158 = 0.6829869747161865
----Batch 159----
Loss for batch 159 = 0.6531922817230225
----Batch 160----
Loss for batch 160 = 0.7419264316558838
----Batch 161----
Loss for batch 161 = 0.6603551506996155
----Batch 162----
Loss for batch 162 = 0.7270459532737732
----Batch 163----
Loss for batch 16

Loss for batch 288 = 0.7981297373771667
----Batch 289----
Loss for batch 289 = 0.7043896913528442
----Batch 290----
Loss for batch 290 = 0.7754504680633545
----Batch 291----
Loss for batch 291 = 0.7603468894958496
----Batch 292----
Loss for batch 292 = 0.7414373159408569
----Batch 293----
Loss for batch 293 = 0.7729094624519348
----Batch 294----
Loss for batch 294 = 0.709554135799408
----Batch 295----
Loss for batch 295 = 0.6676201224327087
----Batch 296----
Loss for batch 296 = 0.7248518466949463
----Batch 297----
Loss for batch 297 = 0.708402693271637
----Batch 298----
Loss for batch 298 = 0.693486750125885
----Batch 299----
Loss for batch 299 = 0.650813639163971
----Batch 300----
Loss for batch 300 = 0.751993715763092
----Batch 301----
Loss for batch 301 = 0.6769618988037109
----Batch 302----
Loss for batch 302 = 0.7759156823158264
----Batch 303----
Loss for batch 303 = 0.6885668635368347
----Batch 304----
Loss for batch 304 = 0.7498618364334106
----Batch 305----
Loss for batch 305 

Loss for batch 430 = 0.7392565608024597
----Batch 431----
Loss for batch 431 = 0.850979208946228
----Batch 432----
Loss for batch 432 = 0.6724632382392883
----Batch 433----
Loss for batch 433 = 0.7479549646377563
----Batch 434----
Loss for batch 434 = 0.6827015280723572
----Batch 435----
Loss for batch 435 = 0.7441122531890869
----Batch 436----
Loss for batch 436 = 0.6397472023963928
----Batch 437----
Loss for batch 437 = 0.7018383741378784
----Batch 438----
Loss for batch 438 = 0.6135950088500977
----Batch 439----
Loss for batch 439 = 0.6915257573127747
----Batch 440----
Loss for batch 440 = 0.7275986671447754
----Batch 441----
