# Data Setup

## Load and Split Data

In [46]:
#Based on code from fastai
import glob
CLASS_VALS = ['neg', 'pos']

def load_text(filepath):
    examples = []
    labels = []
    for index, class_str in enumerate(CLASS_VALS):
        path = filepath + '/' + class_str
        for filename in glob.glob(path + '/*.txt'):
            text = open(filename, 'r').read()
            examples.append(text)
            labels.append(index)
    return examples, labels

data_path = 'aclImdb/'
tv_data, tv_labels = load_text(data_path + 'train')
test_data, test_labels = load_text(data_path + 'test')

In [47]:
from sklearn.model_selection import train_test_split
train_data, val_data, train_labels, val_labels = train_test_split(tv_data, tv_labels, test_size = 0.2)

## Tokenization

In [48]:
#Lab Tokenization

import spacy
import string
from tqdm import tqdm_notebook


# Load English tokenizer, tagger, parser, NER and word vectors
tokenizer = spacy.load('en_core_web_sm')
punctuations = string.punctuation

def lower_case_remove_punc(parsed):
    return [token.text.lower() for token in parsed if (token.text not in punctuations)]

def tokenize_dataset(dataset):
    token_dataset = []
    # we are keeping track of all tokens in dataset
    # in order to create vocabulary later
    all_tokens = []

    for sample in tqdm_notebook(tokenizer.pipe(dataset, disable=['parser', 'tagger', 'ner'], batch_size=512, n_threads=1)):
        tokens = lower_case_remove_punc(sample)
        token_dataset.append(tokens)
        all_tokens += tokens

    return token_dataset, all_tokens


In [49]:
#Lab Tokenization
import pickle as pkl

# val set tokens
print ("Tokenizing val data")
val_data_tokens, _ = tokenize_dataset(val_data)
pkl.dump(val_data_tokens, open("val_data_tokens.p", "wb"))

# test set tokens
print ("Tokenizing test data")
test_data_tokens, _ = tokenize_dataset(test_data)
pkl.dump(test_data_tokens, open("test_data_tokens.p", "wb"))

# train set tokens
print ("Tokenizing train data")
train_data_tokens, all_train_tokens = tokenize_dataset(train_data)
pkl.dump(train_data_tokens, open("train_data_tokens.p", "wb"))
pkl.dump(all_train_tokens, open("all_train_tokens.p", "wb"))

Tokenizing val data


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Tokenizing test data


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Tokenizing train data


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [50]:
#Lab Vocab Builder
from collections import Counter

max_vocab_size = 10000
# save index 0 for unk and 1 for pad
PAD_IDX = 0
UNK_IDX = 1

def build_vocab(all_tokens):
    # Returns:
    # id2token: list of tokens, where id2token[i] returns token that corresponds to token i
    # token2id: dictionary where keys represent tokens and corresponding values represent indices
    token_counter = Counter(all_tokens)
    vocab, count = zip(*token_counter.most_common(max_vocab_size))
    id2token = list(vocab)
    token2id = dict(zip(vocab, range(2,2+len(vocab)))) 
    id2token = ['<pad>', '<unk>'] + id2token
    token2id['<pad>'] = PAD_IDX 
    token2id['<unk>'] = UNK_IDX
    return token2id, id2token

token2id, id2token = build_vocab(all_train_tokens)

In [51]:
# lab tok2id
def token2index_dataset(tokens_data):
    indices_data = []
    for tokens in tokens_data:
        index_list = [token2id[token] if token in token2id else UNK_IDX for token in tokens]
        indices_data.append(index_list)
    return indices_data

train_data_indices = token2index_dataset(train_data_tokens)
val_data_indices = token2index_dataset(val_data_tokens)
test_data_indices = token2index_dataset(test_data_tokens)

# double checking
print ("Train dataset size is {}".format(len(train_data_indices)))
print ("Val dataset size is {}".format(len(val_data_indices)))
print ("Test dataset size is {}".format(len(test_data_indices)))

Train dataset size is 20000
Val dataset size is 5000
Test dataset size is 25000


In [52]:
#Check max length in train
sen_lens = []
for i in range(20000):
    sen_lens.append(len(train_data_indices[i]))
print(max(sen_lens))

2521


# Model

In [123]:
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
Using BOW model from lab. Output logits for numerical stability. Because this is binary classification,
we could also have used BCEWithLogits and used sigmoid for predictions instead, but I encountered dimension
issues trying to do this w/ tensors.  The current softmax/crossentropy is equivalent.
"""
class CBOW(nn.Module):
    
    def __init__(self, vocab_size, emb_dim):
        super(CBOW, self).__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.linear = nn.Linear(emb_dim,2)
        
    def forward(self, data, length):
        out = self.embed(data)
        out = torch.sum(out, dim=1)
        out /= length.view(length.size()[0],1).expand_as(out).float()

        # return logits
        out = self.linear(out.float())
        return out
    

## Data Loader

In [124]:
from torch.utils.data import Dataset
import numpy as np

MAX_SENTENCE = 3000

class ImdbDataset(Dataset):
    """
    Same as lab3 dataset
    """
    
    def __init__(self, data_list, target_list):
        """
        @param data_list: list of newsgroup tokens 
        @param target_list: list of newsgroup targets 

        """
        self.data_list = data_list
        self.target_list = target_list
        assert (len(self.data_list) == len(self.target_list))

    def __len__(self):
        return len(self.data_list)
        
    def __getitem__(self, key):
        """
        Triggered when you call dataset[i]
        """
        
        token_idx = self.data_list[key][:MAX_SENTENCE]
        label = self.target_list[key]
        return [token_idx, len(token_idx), label]
    
def pad_and_collate_func(batch):
    """
    From lab. Tried using nn.utils.rnn.pad_sequence but that does not match the right datatypes
    """
    
    label_list = []
    length_list = []
    padded_data = []
    
    
    for datum in batch:
        label_list.append(datum[2])
        length_list.append(datum[1])
        padded_vec = np.pad(np.array(datum[0]), 
                            pad_width=((0,MAX_SENTENCE-datum[1])), 
                            mode="constant", constant_values=0)
        padded_data.append(padded_vec)

    
    return [torch.from_numpy(np.array(padded_data)), torch.LongTensor(length_list), torch.LongTensor(label_list)]



In [125]:
BATCH_SIZE = 32
train_dataset = ImdbDataset(train_data_indices, train_labels)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=BATCH_SIZE,
                                           collate_fn=pad_and_collate_func,
                                           shuffle=True)

val_dataset = ImdbDataset(val_data_indices, val_labels)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 
                                           batch_size=BATCH_SIZE,
                                           collate_fn=pad_and_collate_func,
                                           shuffle=True)

test_dataset = ImdbDataset(test_data_indices, test_labels)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                           batch_size=BATCH_SIZE,
                                           collate_fn=pad_and_collate_func,
                                           shuffle=False)


## Train Model

In [126]:
e_dim = 100
v_size = len(token2id)
model = CBOW(v_size, e_dim)
learning_rate = 0.01
num_epochs = 10 # number epoch to train

# Criterion and Optimizer
criterion = torch.nn.CrossEntropyLoss()  
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [127]:
# Function for testing the model
def test_model(loader, model):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    total = 0
    model.eval()
    for data, lengths, labels in loader:
        data_batch, length_batch, label_batch = data, lengths, labels
        outputs = F.softmax(model(data_batch, length_batch), dim = 1)
        predicted = outputs.max(1,keepdim=True)[1]    
        
        total += labels.size(0)
        correct += predicted.eq(labels.view_as(predicted)).sum().item()
    return (100 * correct / total)

def train_model()

for epoch in range(num_epochs):
    for i, (data, lengths, labels) in enumerate(train_loader):
        model.train()
        data_batch, length_batch, label_batch = data, lengths, labels
        optimizer.zero_grad()
        outputs = model(data_batch, length_batch)
        loss = criterion(outputs, label_batch)
        loss.backward()
        optimizer.step()
        # validate every 100 iterations
        if i > 0 and i % 100 == 0:
            # validate
            val_acc = test_model(val_loader, model)
            print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format( 
                       epoch+1, num_epochs, i+1, len(train_loader), val_acc))


torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])


torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])


torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
Epoch: [2/10], Step: [201/625], Validation Acc: 86.96
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.

torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
Epoch: [2/10], Step: [601/625], Validation Acc: 87.18
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.

torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
Epoch: [3/10], Step: [401/625], Validation Acc: 87.24
torch.Size([32, 2])
torch.Size([32, 2])
torch.

torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])


torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])


torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])


torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
Epoch: [6/10], Step: [101/625], Validation Acc: 86.6
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.S

torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
Epoch: [6/10], Step: [501/625], Validation Acc: 85.72
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.

torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
Epoch: [7/10], Step: [301/625], Validation Acc: 85.9
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.S

torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])


torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])


torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])


torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])


torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
Epoch: [10/10], Step: [401/625], Validation Acc: 85.36
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch.Size([32, 2])
torch