In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pandas as pd

In [2]:
df = pd.read_csv('champions.csv')
def get_champ_vec(champ_name):
    arr = []
    row = df.loc[df['Champion'] == champ_name].iloc[0]
    for i in range(1, 16):
        if i == 7:
            if row[i] == 1:
                arr.append(1)
                arr.append(0)
                arr.append(0)
            elif row[i] == 2:
                arr.append(0)
                arr.append(1)
                arr.append(0)
            elif row[i] == 3:
                arr.append(0)
                arr.append(0)
                arr.append(1)
        else:
            arr.append(row[i] / 10)
    return arr

def process_match(match):
        champs = match.split(',')
        vec = []
        for i in range(10):
            champ_vec = get_champ_vec(champs[i])
            for e in champ_vec:
                vec.append(e)
        if champs[10] == 'true':
            vec.append(1)
        else:
            vec.append(0)
        return vec

class MatchDataset(Dataset):
    def __init__(self, file_name, num_games, skip):
        self.inputs = []
        self.labels = []
        with open('data/matches/' + file_name, 'r') as f:
            matches = f.read().split('\n')
            for i in range(skip, skip+num_games):
                try:
                    match_vec = process_match(matches[i])
                    self.inputs.append(torch.tensor(match_vec[:-1], dtype=torch.float32))
                    self.labels.append(torch.tensor([match_vec[-1]], dtype=torch.float32))
                except:
                    pass
                    
    def __len__(self):
        return len(self.inputs)
    
    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]

In [7]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(170, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        logits = self.linear_relu_stack(x)
        return logits
model = NeuralNetwork()

In [4]:
learning_rate = 1e-4
batch_size = 64
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

training_data = MatchDataset('diamond_training_data.txt', 150000, 0)
testing_data = MatchDataset('diamond_training_data.txt', 10000, 150000)
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(testing_data, batch_size=batch_size, shuffle=True)

In [5]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    avgLoss = 0
    count = 0
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = torch.reshape(model(X), (X.shape[0],))
        loss = loss_fn(pred, torch.reshape(y, (y.shape[0],)))

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            avgLoss += loss.item()
            count += 1
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    return avgLoss / count

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = torch.reshape(model(X), (X.shape[0],))
            test_loss += loss_fn(pred, torch.reshape(y, (y.shape[0],))).item()
            for i in range(pred.shape[0]):
                guess = 0
                if pred[i] > 0.5:
                    guess = 1
                if guess == y[i]:
                    correct += 1

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss

In [8]:
epochs = 200
train_losses = []
test_losses = []
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loss = train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loss = test_loop(test_dataloader, model, loss_fn)
    train_losses.append(train_loss)
    test_losses.append(test_losses)
print("Done!")

Epoch 1
-------------------------------
loss: 0.694327  [   64/149994]
loss: 0.704204  [ 6464/149994]
loss: 0.694567  [12864/149994]
loss: 0.703273  [19264/149994]
loss: 0.704242  [25664/149994]
loss: 0.691484  [32064/149994]
loss: 0.694522  [38464/149994]
loss: 0.698154  [44864/149994]
loss: 0.692793  [51264/149994]
loss: 0.696294  [57664/149994]
loss: 0.687717  [64064/149994]
loss: 0.696342  [70464/149994]
loss: 0.706071  [76864/149994]
loss: 0.689858  [83264/149994]
loss: 0.696292  [89664/149994]
loss: 0.692684  [96064/149994]
loss: 0.709491  [102464/149994]
loss: 0.688191  [108864/149994]
loss: 0.691322  [115264/149994]
loss: 0.686390  [121664/149994]
loss: 0.696200  [128064/149994]
loss: 0.683331  [134464/149994]
loss: 0.706800  [140864/149994]
loss: 0.714743  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694642 

Epoch 2
-------------------------------
loss: 0.714360  [   64/149994]
loss: 0.682828  [ 6464/149994]
loss: 0.704445  [12864/149994]
loss: 0.692934  [19264/

loss: 0.706471  [115264/149994]
loss: 0.689496  [121664/149994]
loss: 0.704524  [128064/149994]
loss: 0.694438  [134464/149994]
loss: 0.681173  [140864/149994]
loss: 0.696023  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694720 

Epoch 11
-------------------------------
loss: 0.696158  [   64/149994]
loss: 0.686172  [ 6464/149994]
loss: 0.687873  [12864/149994]
loss: 0.691360  [19264/149994]
loss: 0.697631  [25664/149994]
loss: 0.698085  [32064/149994]
loss: 0.696103  [38464/149994]
loss: 0.696368  [44864/149994]
loss: 0.688280  [51264/149994]
loss: 0.706206  [57664/149994]
loss: 0.698028  [64064/149994]
loss: 0.688100  [70464/149994]
loss: 0.707710  [76864/149994]
loss: 0.694534  [83264/149994]
loss: 0.679863  [89664/149994]
loss: 0.694614  [96064/149994]
loss: 0.701268  [102464/149994]
loss: 0.698076  [108864/149994]
loss: 0.708144  [115264/149994]
loss: 0.699807  [121664/149994]
loss: 0.691167  [128064/149994]
loss: 0.699253  [134464/149994]
loss: 0.699771  [140864/149

loss: 0.694767  [76864/149994]
loss: 0.696304  [83264/149994]
loss: 0.697814  [89664/149994]
loss: 0.689643  [96064/149994]
loss: 0.684534  [102464/149994]
loss: 0.697800  [108864/149994]
loss: 0.698334  [115264/149994]
loss: 0.696338  [121664/149994]
loss: 0.676102  [128064/149994]
loss: 0.696196  [134464/149994]
loss: 0.699648  [140864/149994]
loss: 0.699414  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694718 

Epoch 21
-------------------------------
loss: 0.696159  [   64/149994]
loss: 0.701312  [ 6464/149994]
loss: 0.699718  [12864/149994]
loss: 0.701218  [19264/149994]
loss: 0.689705  [25664/149994]
loss: 0.692931  [32064/149994]
loss: 0.689358  [38464/149994]
loss: 0.697736  [44864/149994]
loss: 0.700789  [51264/149994]
loss: 0.691082  [57664/149994]
loss: 0.697954  [64064/149994]
loss: 0.689591  [70464/149994]
loss: 0.693087  [76864/149994]
loss: 0.696118  [83264/149994]
loss: 0.701404  [89664/149994]
loss: 0.693405  [96064/149994]
loss: 0.694672  [102464/149994]

loss: 0.691256  [38464/149994]
loss: 0.695912  [44864/149994]
loss: 0.694344  [51264/149994]
loss: 0.686168  [57664/149994]
loss: 0.694814  [64064/149994]
loss: 0.706251  [70464/149994]
loss: 0.703134  [76864/149994]
loss: 0.697619  [83264/149994]
loss: 0.694357  [89664/149994]
loss: 0.707771  [96064/149994]
loss: 0.696092  [102464/149994]
loss: 0.689614  [108864/149994]
loss: 0.689780  [115264/149994]
loss: 0.691197  [121664/149994]
loss: 0.689408  [128064/149994]
loss: 0.696063  [134464/149994]
loss: 0.684713  [140864/149994]
loss: 0.691349  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694871 

Epoch 31
-------------------------------
loss: 0.693005  [   64/149994]
loss: 0.687718  [ 6464/149994]
loss: 0.688368  [12864/149994]
loss: 0.696074  [19264/149994]
loss: 0.702971  [25664/149994]
loss: 0.683177  [32064/149994]
loss: 0.689734  [38464/149994]
loss: 0.692587  [44864/149994]
loss: 0.689773  [51264/149994]
loss: 0.705940  [57664/149994]
loss: 0.694482  [64064/149994]


loss: 0.706387  [19264/149994]
loss: 0.686195  [25664/149994]
loss: 0.689577  [32064/149994]
loss: 0.697852  [38464/149994]
loss: 0.694699  [44864/149994]
loss: 0.688011  [51264/149994]
loss: 0.677841  [57664/149994]
loss: 0.689518  [64064/149994]
loss: 0.691094  [70464/149994]
loss: 0.691584  [76864/149994]
loss: 0.681422  [83264/149994]
loss: 0.699631  [89664/149994]
loss: 0.696397  [96064/149994]
loss: 0.687844  [102464/149994]
loss: 0.691103  [108864/149994]
loss: 0.701122  [115264/149994]
loss: 0.699399  [121664/149994]
loss: 0.694838  [128064/149994]
loss: 0.706221  [134464/149994]
loss: 0.693035  [140864/149994]
loss: 0.699636  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694678 

Epoch 41
-------------------------------
loss: 0.692884  [   64/149994]
loss: 0.693083  [ 6464/149994]
loss: 0.699347  [12864/149994]
loss: 0.706488  [19264/149994]
loss: 0.699472  [25664/149994]
loss: 0.689682  [32064/149994]
loss: 0.703201  [38464/149994]
loss: 0.691185  [44864/149994]


Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694832 

Epoch 50
-------------------------------
loss: 0.694708  [   64/149994]
loss: 0.696457  [ 6464/149994]
loss: 0.698154  [12864/149994]
loss: 0.694643  [19264/149994]
loss: 0.696462  [25664/149994]
loss: 0.696290  [32064/149994]
loss: 0.691192  [38464/149994]
loss: 0.694505  [44864/149994]
loss: 0.694695  [51264/149994]
loss: 0.684319  [57664/149994]
loss: 0.687905  [64064/149994]
loss: 0.687514  [70464/149994]
loss: 0.681523  [76864/149994]
loss: 0.681563  [83264/149994]
loss: 0.701009  [89664/149994]
loss: 0.689495  [96064/149994]
loss: 0.709459  [102464/149994]
loss: 0.689342  [108864/149994]
loss: 0.697684  [115264/149994]
loss: 0.703149  [121664/149994]
loss: 0.702787  [128064/149994]
loss: 0.691482  [134464/149994]
loss: 0.689736  [140864/149994]
loss: 0.701125  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694680 

Epoch 51
-------------------------------
loss: 0.701146  [   64/149994]
loss: 0.684792  [ 6464/149994]


loss: 0.679671  [115264/149994]
loss: 0.691034  [121664/149994]
loss: 0.712760  [128064/149994]
loss: 0.698073  [134464/149994]
loss: 0.692917  [140864/149994]
loss: 0.697644  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694793 

Epoch 60
-------------------------------
loss: 0.697935  [   64/149994]
loss: 0.692713  [ 6464/149994]
loss: 0.691197  [12864/149994]
loss: 0.702882  [19264/149994]
loss: 0.686046  [25664/149994]
loss: 0.694665  [32064/149994]
loss: 0.699671  [38464/149994]
loss: 0.692728  [44864/149994]
loss: 0.699642  [51264/149994]
loss: 0.701425  [57664/149994]
loss: 0.684488  [64064/149994]
loss: 0.692929  [70464/149994]
loss: 0.694577  [76864/149994]
loss: 0.692995  [83264/149994]
loss: 0.696243  [89664/149994]
loss: 0.694825  [96064/149994]
loss: 0.702847  [102464/149994]
loss: 0.692894  [108864/149994]
loss: 0.704281  [115264/149994]
loss: 0.696265  [121664/149994]
loss: 0.694350  [128064/149994]
loss: 0.699373  [134464/149994]
loss: 0.706350  [140864/149

loss: 0.698016  [76864/149994]
loss: 0.702722  [83264/149994]
loss: 0.686219  [89664/149994]
loss: 0.701196  [96064/149994]
loss: 0.696184  [102464/149994]
loss: 0.704406  [108864/149994]
loss: 0.698042  [115264/149994]
loss: 0.686198  [121664/149994]
loss: 0.689787  [128064/149994]
loss: 0.701528  [134464/149994]
loss: 0.699564  [140864/149994]
loss: 0.679717  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694681 

Epoch 70
-------------------------------
loss: 0.692903  [   64/149994]
loss: 0.696216  [ 6464/149994]
loss: 0.699523  [12864/149994]
loss: 0.691546  [19264/149994]
loss: 0.691336  [25664/149994]
loss: 0.694594  [32064/149994]
loss: 0.692904  [38464/149994]
loss: 0.687629  [44864/149994]
loss: 0.699446  [51264/149994]
loss: 0.704643  [57664/149994]
loss: 0.694649  [64064/149994]
loss: 0.703102  [70464/149994]
loss: 0.696169  [76864/149994]
loss: 0.693014  [83264/149994]
loss: 0.698207  [89664/149994]
loss: 0.684511  [96064/149994]
loss: 0.689659  [102464/149994]

loss: 0.701168  [38464/149994]
loss: 0.698098  [44864/149994]
loss: 0.701333  [51264/149994]
loss: 0.694593  [57664/149994]
loss: 0.689931  [64064/149994]
loss: 0.696392  [70464/149994]
loss: 0.696298  [76864/149994]
loss: 0.697818  [83264/149994]
loss: 0.710994  [89664/149994]
loss: 0.689533  [96064/149994]
loss: 0.699564  [102464/149994]
loss: 0.698062  [108864/149994]
loss: 0.697999  [115264/149994]
loss: 0.694531  [121664/149994]
loss: 0.692912  [128064/149994]
loss: 0.706086  [134464/149994]
loss: 0.693242  [140864/149994]
loss: 0.699504  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694756 

Epoch 80
-------------------------------
loss: 0.694186  [   64/149994]
loss: 0.692876  [ 6464/149994]
loss: 0.699502  [12864/149994]
loss: 0.693134  [19264/149994]
loss: 0.702873  [25664/149994]
loss: 0.701244  [32064/149994]
loss: 0.696120  [38464/149994]
loss: 0.702889  [44864/149994]
loss: 0.706262  [51264/149994]
loss: 0.698091  [57664/149994]
loss: 0.691223  [64064/149994]


loss: 0.694636  [19264/149994]
loss: 0.692627  [25664/149994]
loss: 0.692569  [32064/149994]
loss: 0.688036  [38464/149994]
loss: 0.701196  [44864/149994]
loss: 0.699603  [51264/149994]
loss: 0.699492  [57664/149994]
loss: 0.699270  [64064/149994]
loss: 0.706233  [70464/149994]
loss: 0.696341  [76864/149994]
loss: 0.689653  [83264/149994]
loss: 0.697812  [89664/149994]
loss: 0.701244  [96064/149994]
loss: 0.699472  [102464/149994]
loss: 0.693165  [108864/149994]
loss: 0.689674  [115264/149994]
loss: 0.687796  [121664/149994]
loss: 0.691143  [128064/149994]
loss: 0.687472  [134464/149994]
loss: 0.706067  [140864/149994]
loss: 0.707808  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694791 

Epoch 90
-------------------------------
loss: 0.687907  [   64/149994]
loss: 0.692763  [ 6464/149994]
loss: 0.697772  [12864/149994]
loss: 0.701143  [19264/149994]
loss: 0.702566  [25664/149994]
loss: 0.689839  [32064/149994]
loss: 0.687719  [38464/149994]
loss: 0.699678  [44864/149994]


Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694755 

Epoch 99
-------------------------------
loss: 0.704722  [   64/149994]
loss: 0.694440  [ 6464/149994]
loss: 0.701556  [12864/149994]
loss: 0.688077  [19264/149994]
loss: 0.691246  [25664/149994]
loss: 0.687661  [32064/149994]
loss: 0.698147  [38464/149994]
loss: 0.691353  [44864/149994]
loss: 0.681135  [51264/149994]
loss: 0.695024  [57664/149994]
loss: 0.682919  [64064/149994]
loss: 0.704482  [70464/149994]
loss: 0.696044  [76864/149994]
loss: 0.696560  [83264/149994]
loss: 0.706651  [89664/149994]
loss: 0.702506  [96064/149994]
loss: 0.694556  [102464/149994]
loss: 0.689676  [108864/149994]
loss: 0.691323  [115264/149994]
loss: 0.696038  [121664/149994]
loss: 0.694284  [128064/149994]
loss: 0.687893  [134464/149994]
loss: 0.694255  [140864/149994]
loss: 0.684676  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694681 

Epoch 100
-------------------------------
loss: 0.694356  [   64/149994]
loss: 0.692712  [ 6464/149994]

loss: 0.701074  [115264/149994]
loss: 0.696497  [121664/149994]
loss: 0.696106  [128064/149994]
loss: 0.694572  [134464/149994]
loss: 0.696515  [140864/149994]
loss: 0.696211  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694716 

Epoch 109
-------------------------------
loss: 0.695809  [   64/149994]
loss: 0.691420  [ 6464/149994]
loss: 0.696063  [12864/149994]
loss: 0.691092  [19264/149994]
loss: 0.692764  [25664/149994]
loss: 0.706295  [32064/149994]
loss: 0.697849  [38464/149994]
loss: 0.691406  [44864/149994]
loss: 0.709597  [51264/149994]
loss: 0.691221  [57664/149994]
loss: 0.691560  [64064/149994]
loss: 0.701323  [70464/149994]
loss: 0.687523  [76864/149994]
loss: 0.697957  [83264/149994]
loss: 0.691129  [89664/149994]
loss: 0.693054  [96064/149994]
loss: 0.681455  [102464/149994]
loss: 0.688078  [108864/149994]
loss: 0.704758  [115264/149994]
loss: 0.701379  [121664/149994]
loss: 0.690913  [128064/149994]
loss: 0.684476  [134464/149994]
loss: 0.694609  [140864/14

loss: 0.687691  [76864/149994]
loss: 0.689579  [83264/149994]
loss: 0.699206  [89664/149994]
loss: 0.697950  [96064/149994]
loss: 0.692811  [102464/149994]
loss: 0.696682  [108864/149994]
loss: 0.689433  [115264/149994]
loss: 0.693098  [121664/149994]
loss: 0.692802  [128064/149994]
loss: 0.691373  [134464/149994]
loss: 0.692455  [140864/149994]
loss: 0.694813  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694715 

Epoch 119
-------------------------------
loss: 0.700922  [   64/149994]
loss: 0.697785  [ 6464/149994]
loss: 0.690895  [12864/149994]
loss: 0.692763  [19264/149994]
loss: 0.706240  [25664/149994]
loss: 0.696621  [32064/149994]
loss: 0.714256  [38464/149994]
loss: 0.691394  [44864/149994]
loss: 0.689646  [51264/149994]
loss: 0.689820  [57664/149994]
loss: 0.692961  [64064/149994]
loss: 0.696157  [70464/149994]
loss: 0.701062  [76864/149994]
loss: 0.699087  [83264/149994]
loss: 0.694483  [89664/149994]
loss: 0.692689  [96064/149994]
loss: 0.690884  [102464/149994

loss: 0.689328  [38464/149994]
loss: 0.696200  [44864/149994]
loss: 0.694141  [51264/149994]
loss: 0.694193  [57664/149994]
loss: 0.699802  [64064/149994]
loss: 0.697887  [70464/149994]
loss: 0.687922  [76864/149994]
loss: 0.688014  [83264/149994]
loss: 0.689489  [89664/149994]
loss: 0.682760  [96064/149994]
loss: 0.704662  [102464/149994]
loss: 0.696320  [108864/149994]
loss: 0.696114  [115264/149994]
loss: 0.694914  [121664/149994]
loss: 0.694619  [128064/149994]
loss: 0.697804  [134464/149994]
loss: 0.691343  [140864/149994]
loss: 0.699672  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694607 

Epoch 129
-------------------------------
loss: 0.700790  [   64/149994]
loss: 0.699831  [ 6464/149994]
loss: 0.704623  [12864/149994]
loss: 0.679727  [19264/149994]
loss: 0.691219  [25664/149994]
loss: 0.692790  [32064/149994]
loss: 0.701579  [38464/149994]
loss: 0.684223  [44864/149994]
loss: 0.696409  [51264/149994]
loss: 0.686132  [57664/149994]
loss: 0.699696  [64064/149994]

loss: 0.699240  [19264/149994]
loss: 0.690015  [25664/149994]
loss: 0.696396  [32064/149994]
loss: 0.693115  [38464/149994]
loss: 0.696237  [44864/149994]
loss: 0.692824  [51264/149994]
loss: 0.699516  [57664/149994]
loss: 0.693023  [64064/149994]
loss: 0.694561  [70464/149994]
loss: 0.696176  [76864/149994]
loss: 0.699354  [83264/149994]
loss: 0.686453  [89664/149994]
loss: 0.692964  [96064/149994]
loss: 0.695996  [102464/149994]
loss: 0.684326  [108864/149994]
loss: 0.699652  [115264/149994]
loss: 0.692718  [121664/149994]
loss: 0.686160  [128064/149994]
loss: 0.697928  [134464/149994]
loss: 0.703091  [140864/149994]
loss: 0.694505  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694760 

Epoch 139
-------------------------------
loss: 0.689207  [   64/149994]
loss: 0.687652  [ 6464/149994]
loss: 0.689854  [12864/149994]
loss: 0.711194  [19264/149994]
loss: 0.694695  [25664/149994]
loss: 0.691477  [32064/149994]
loss: 0.677716  [38464/149994]
loss: 0.707680  [44864/149994]

Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694716 

Epoch 148
-------------------------------
loss: 0.702890  [   64/149994]
loss: 0.687619  [ 6464/149994]
loss: 0.689744  [12864/149994]
loss: 0.701309  [19264/149994]
loss: 0.699614  [25664/149994]
loss: 0.694364  [32064/149994]
loss: 0.701131  [38464/149994]
loss: 0.684403  [44864/149994]
loss: 0.694248  [51264/149994]
loss: 0.704604  [57664/149994]
loss: 0.697853  [64064/149994]
loss: 0.692865  [70464/149994]
loss: 0.700827  [76864/149994]
loss: 0.688096  [83264/149994]
loss: 0.698163  [89664/149994]
loss: 0.697885  [96064/149994]
loss: 0.696156  [102464/149994]
loss: 0.691303  [108864/149994]
loss: 0.683200  [115264/149994]
loss: 0.700795  [121664/149994]
loss: 0.689735  [128064/149994]
loss: 0.691018  [134464/149994]
loss: 0.702666  [140864/149994]
loss: 0.692703  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694715 

Epoch 149
-------------------------------
loss: 0.695961  [   64/149994]
loss: 0.699467  [ 6464/149994

loss: 0.691120  [115264/149994]
loss: 0.689543  [121664/149994]
loss: 0.694165  [128064/149994]
loss: 0.692185  [134464/149994]
loss: 0.697605  [140864/149994]
loss: 0.699630  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694794 

Epoch 158
-------------------------------
loss: 0.689738  [   64/149994]
loss: 0.709416  [ 6464/149994]
loss: 0.697733  [12864/149994]
loss: 0.686200  [19264/149994]
loss: 0.691151  [25664/149994]
loss: 0.687870  [32064/149994]
loss: 0.686510  [38464/149994]
loss: 0.697735  [44864/149994]
loss: 0.697618  [51264/149994]
loss: 0.689832  [57664/149994]
loss: 0.694451  [64064/149994]
loss: 0.695928  [70464/149994]
loss: 0.691343  [76864/149994]
loss: 0.696020  [83264/149994]
loss: 0.676138  [89664/149994]
loss: 0.694720  [96064/149994]
loss: 0.694239  [102464/149994]
loss: 0.699707  [108864/149994]
loss: 0.681336  [115264/149994]
loss: 0.691381  [121664/149994]
loss: 0.692927  [128064/149994]
loss: 0.686106  [134464/149994]
loss: 0.696194  [140864/14

loss: 0.696666  [76864/149994]
loss: 0.687962  [83264/149994]
loss: 0.697650  [89664/149994]
loss: 0.686587  [96064/149994]
loss: 0.704752  [102464/149994]
loss: 0.689815  [108864/149994]
loss: 0.696009  [115264/149994]
loss: 0.697721  [121664/149994]
loss: 0.694710  [128064/149994]
loss: 0.697699  [134464/149994]
loss: 0.694658  [140864/149994]
loss: 0.694510  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694646 

Epoch 168
-------------------------------
loss: 0.701512  [   64/149994]
loss: 0.696207  [ 6464/149994]
loss: 0.697693  [12864/149994]
loss: 0.707790  [19264/149994]
loss: 0.689338  [25664/149994]
loss: 0.700999  [32064/149994]
loss: 0.704479  [38464/149994]
loss: 0.694767  [44864/149994]
loss: 0.696599  [51264/149994]
loss: 0.698203  [57664/149994]
loss: 0.689806  [64064/149994]
loss: 0.689795  [70464/149994]
loss: 0.687925  [76864/149994]
loss: 0.686201  [83264/149994]
loss: 0.697927  [89664/149994]
loss: 0.696383  [96064/149994]
loss: 0.696314  [102464/149994

loss: 0.704528  [38464/149994]
loss: 0.692773  [44864/149994]
loss: 0.691059  [51264/149994]
loss: 0.702853  [57664/149994]
loss: 0.715959  [64064/149994]
loss: 0.709526  [70464/149994]
loss: 0.702970  [76864/149994]
loss: 0.698083  [83264/149994]
loss: 0.699652  [89664/149994]
loss: 0.687888  [96064/149994]
loss: 0.692694  [102464/149994]
loss: 0.689591  [108864/149994]
loss: 0.692958  [115264/149994]
loss: 0.697772  [121664/149994]
loss: 0.701081  [128064/149994]
loss: 0.699794  [134464/149994]
loss: 0.703081  [140864/149994]
loss: 0.707663  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694718 

Epoch 178
-------------------------------
loss: 0.691144  [   64/149994]
loss: 0.711182  [ 6464/149994]
loss: 0.694646  [12864/149994]
loss: 0.694536  [19264/149994]
loss: 0.696223  [25664/149994]
loss: 0.691271  [32064/149994]
loss: 0.687924  [38464/149994]
loss: 0.694319  [44864/149994]
loss: 0.694228  [51264/149994]
loss: 0.702939  [57664/149994]
loss: 0.697852  [64064/149994]

loss: 0.694578  [19264/149994]
loss: 0.681373  [25664/149994]
loss: 0.689623  [32064/149994]
loss: 0.689172  [38464/149994]
loss: 0.696148  [44864/149994]
loss: 0.706320  [51264/149994]
loss: 0.699730  [57664/149994]
loss: 0.686242  [64064/149994]
loss: 0.694526  [70464/149994]
loss: 0.702785  [76864/149994]
loss: 0.699499  [83264/149994]
loss: 0.687754  [89664/149994]
loss: 0.689734  [96064/149994]
loss: 0.694660  [102464/149994]
loss: 0.689538  [108864/149994]
loss: 0.691256  [115264/149994]
loss: 0.694672  [121664/149994]
loss: 0.692798  [128064/149994]
loss: 0.699639  [134464/149994]
loss: 0.704194  [140864/149994]
loss: 0.691202  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694721 

Epoch 188
-------------------------------
loss: 0.697714  [   64/149994]
loss: 0.693171  [ 6464/149994]
loss: 0.692858  [12864/149994]
loss: 0.689622  [19264/149994]
loss: 0.696532  [25664/149994]
loss: 0.693003  [32064/149994]
loss: 0.704711  [38464/149994]
loss: 0.703296  [44864/149994]

Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694756 

Epoch 197
-------------------------------
loss: 0.680864  [   64/149994]
loss: 0.679610  [ 6464/149994]
loss: 0.686106  [12864/149994]
loss: 0.696342  [19264/149994]
loss: 0.699662  [25664/149994]
loss: 0.696338  [32064/149994]
loss: 0.697758  [38464/149994]
loss: 0.699726  [44864/149994]
loss: 0.689641  [51264/149994]
loss: 0.686231  [57664/149994]
loss: 0.696393  [64064/149994]
loss: 0.689881  [70464/149994]
loss: 0.697828  [76864/149994]
loss: 0.699229  [83264/149994]
loss: 0.689411  [89664/149994]
loss: 0.691298  [96064/149994]
loss: 0.702639  [102464/149994]
loss: 0.696305  [108864/149994]
loss: 0.694297  [115264/149994]
loss: 0.691512  [121664/149994]
loss: 0.693515  [128064/149994]
loss: 0.689766  [134464/149994]
loss: 0.692944  [140864/149994]
loss: 0.703174  [147264/149994]
Test Error: 
 Accuracy: 49.8%, Avg loss: 0.694795 

Epoch 198
-------------------------------
loss: 0.708102  [   64/149994]
loss: 0.699481  [ 6464/149994

In [None]:
train_losses

In [None]:
import matplotlib.pyplot as plt
plt.figure(figsize=(5, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(test_losses, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.xscale('log')
plt.ylabel('Loss')
plt.legend()
plt.show()