In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils import data

import pandas as pd
import numpy as np

import os
import os.path as op

path = '../../data/kaggle-titanic'
for dirname, _, filenames in os.walk(path):
    for filename in filenames:
        print(os.path.join(dirname, filename))
        
setNames = ['train_RescaleClean.csv','train_label.csv','test_RescaleClean.csv']
dfs = []
for sn in setNames:
    dfs.append( pd.read_csv(op.join(path,sn), index_col=0))
    print(dfs[-1].head())
    print(dfs[-1].info())
    
with open(op.join(path,"significantFeatures.txt"), "r") as file:
    features = [f.split('\n')[0] for f in file.readlines()]
print(features)

../../data/kaggle-titanic/train_label.csv
../../data/kaggle-titanic/test_RescaleClean.csv
../../data/kaggle-titanic/train_RescaleClean.csv
../../data/kaggle-titanic/test.csv
../../data/kaggle-titanic/submission.csv
../../data/kaggle-titanic/train.csv
../../data/kaggle-titanic/significantFeatures.txt
   Pclass  Sex  SibSp  Parch  Fare_group  Embarked  Age_group  Title
0     1.0  0.0  0.125    0.0         0.0       0.0        0.2    0.0
1     0.0  1.0  0.125    0.0         0.0       0.5        0.4    0.5
2     1.0  1.0  0.000    0.0         0.0       0.0        0.2    1.0
3     0.0  1.0  0.125    0.0         0.0       0.0        0.4    0.5
4     1.0  0.0  0.000    0.0         0.0       0.0        0.4    0.0
<class 'pandas.core.frame.DataFrame'>
Int64Index: 891 entries, 0 to 890
Data columns (total 8 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   Pclass      891 non-null    float64
 1   Sex         891 non-null    float64
 2   SibSp       

In [2]:
# Download training data from open datasets.
class customTitanic(data.Dataset):
    
    def __init__(self, df, features, y):
        
        self.df_selected = df[features] #type: pandas.core.frame.DataFrame
        self.selected = np.stack([c.values for n, c in self.df_selected.items()], axis=1).astype(np.float32) #tpye: numpy.ndarray
    
        self.y = y.values.astype(np.float32)
        
    def __len__(self): return len(self.y)

    def __getitem__(self, idx):
        
        return [self.selected[idx], self.y[idx]]

In [3]:
X, y = dfs[0], dfs[1]['Survived']
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.1)
training_data = customTitanic(X_train, features,y_train)
test_data = customTitanic(X_test, features, y_test)

batch_size = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
print(test_dataloader)
for Xt, yt in test_dataloader:
    print(f"Shape of X : {Xt.shape}")
    print(f"Shape of y: {yt.shape} {yt.dtype}")
    break

<torch.utils.data.dataloader.DataLoader object at 0x7f2c850f6910>
Shape of X : torch.Size([64, 7])
Shape of y: torch.Size([64]) torch.float32


In [4]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(7, 20),
            nn.ReLU(),
            nn.Linear(20, 1),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

Using cpu device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=7, out_features=20, bias=True)
    (1): ReLU()
    (2): Linear(in_features=20, out_features=1, bias=True)
  )
)


In [5]:
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [6]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        
        loss = loss_fn(pred.squeeze(), y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss, current = loss.item(), (batch+1) * len(X)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [7]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred.squeeze(), y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [8]:
epochs = 500
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 0.677829  [   64/  801]
loss: 0.668002  [  128/  801]
loss: 0.674518  [  192/  801]
loss: 0.674881  [  256/  801]
loss: 0.672066  [  320/  801]
loss: 0.672198  [  384/  801]
loss: 0.677583  [  448/  801]
loss: 0.671706  [  512/  801]
loss: 0.678611  [  576/  801]
loss: 0.669760  [  640/  801]
loss: 0.670751  [  704/  801]
loss: 0.680424  [  768/  801]
loss: 0.686861  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.674988 

Epoch 2
-------------------------------
loss: 0.677676  [   64/  801]
loss: 0.667690  [  128/  801]
loss: 0.674185  [  192/  801]
loss: 0.674611  [  256/  801]
loss: 0.671744  [  320/  801]
loss: 0.671981  [  384/  801]
loss: 0.677246  [  448/  801]
loss: 0.671368  [  512/  801]
loss: 0.678383  [  576/  801]
loss: 0.669387  [  640/  801]
loss: 0.670428  [  704/  801]
loss: 0.680236  [  768/  801]
loss: 0.686743  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.674690 

Epoch 3
----------------------------

loss: 0.669071  [  192/  801]
loss: 0.670463  [  256/  801]
loss: 0.666749  [  320/  801]
loss: 0.668637  [  384/  801]
loss: 0.672083  [  448/  801]
loss: 0.666209  [  512/  801]
loss: 0.674880  [  576/  801]
loss: 0.663681  [  640/  801]
loss: 0.665440  [  704/  801]
loss: 0.677380  [  768/  801]
loss: 0.684874  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.670119 

Epoch 19
-------------------------------
loss: 0.675193  [   64/  801]
loss: 0.662625  [  128/  801]
loss: 0.668764  [  192/  801]
loss: 0.670214  [  256/  801]
loss: 0.666446  [  320/  801]
loss: 0.668436  [  384/  801]
loss: 0.671774  [  448/  801]
loss: 0.665902  [  512/  801]
loss: 0.674669  [  576/  801]
loss: 0.663339  [  640/  801]
loss: 0.665139  [  704/  801]
loss: 0.677210  [  768/  801]
loss: 0.684759  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.669845 

Epoch 20
-------------------------------
loss: 0.675053  [   64/  801]
loss: 0.662341  [  128/  801]
loss: 0.668459  [  192/  801]
loss

loss: 0.666122  [  256/  801]
loss: 0.661425  [  320/  801]
loss: 0.665120  [  384/  801]
loss: 0.666738  [  448/  801]
loss: 0.660906  [  512/  801]
loss: 0.671220  [  576/  801]
loss: 0.657747  [  640/  801]
loss: 0.660174  [  704/  801]
loss: 0.674440  [  768/  801]
loss: 0.682790  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.665362 

Epoch 37
-------------------------------
loss: 0.672751  [   64/  801]
loss: 0.657703  [  128/  801]
loss: 0.663462  [  192/  801]
loss: 0.665888  [  256/  801]
loss: 0.661136  [  320/  801]
loss: 0.664929  [  384/  801]
loss: 0.666453  [  448/  801]
loss: 0.660625  [  512/  801]
loss: 0.671023  [  576/  801]
loss: 0.657430  [  640/  801]
loss: 0.659890  [  704/  801]
loss: 0.674283  [  768/  801]
loss: 0.682674  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.665108 

Epoch 38
-------------------------------
loss: 0.672619  [   64/  801]
loss: 0.657440  [  128/  801]
loss: 0.663179  [  192/  801]
loss: 0.665656  [  256/  801]
loss

loss: 0.680672  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.660919 

Epoch 55
-------------------------------
loss: 0.670443  [   64/  801]
loss: 0.653145  [  128/  801]
loss: 0.658529  [  192/  801]
loss: 0.661817  [  256/  801]
loss: 0.656051  [  320/  801]
loss: 0.661596  [  384/  801]
loss: 0.661511  [  448/  801]
loss: 0.655787  [  512/  801]
loss: 0.667597  [  576/  801]
loss: 0.651937  [  640/  801]
loss: 0.654908  [  704/  801]
loss: 0.671558  [  768/  801]
loss: 0.680552  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.660679 

Epoch 56
-------------------------------
loss: 0.670317  [   64/  801]
loss: 0.652900  [  128/  801]
loss: 0.658264  [  192/  801]
loss: 0.661597  [  256/  801]
loss: 0.655773  [  320/  801]
loss: 0.661415  [  384/  801]
loss: 0.661246  [  448/  801]
loss: 0.655529  [  512/  801]
loss: 0.667410  [  576/  801]
loss: 0.651643  [  640/  801]
loss: 0.654637  [  704/  801]
loss: 0.671411  [  768/  801]
loss: 0.680432  [  429/  801]
Test

loss: 0.668203  [   64/  801]
loss: 0.648870  [  128/  801]
loss: 0.653887  [  192/  801]
loss: 0.657932  [  256/  801]
loss: 0.651114  [  320/  801]
loss: 0.658383  [  384/  801]
loss: 0.656875  [  448/  801]
loss: 0.651287  [  512/  801]
loss: 0.664304  [  576/  801]
loss: 0.646774  [  640/  801]
loss: 0.650121  [  704/  801]
loss: 0.668980  [  768/  801]
loss: 0.678342  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.656493 

Epoch 74
-------------------------------
loss: 0.668079  [   64/  801]
loss: 0.648640  [  128/  801]
loss: 0.653636  [  192/  801]
loss: 0.657720  [  256/  801]
loss: 0.650842  [  320/  801]
loss: 0.658207  [  384/  801]
loss: 0.656625  [  448/  801]
loss: 0.651045  [  512/  801]
loss: 0.664124  [  576/  801]
loss: 0.646495  [  640/  801]
loss: 0.649860  [  704/  801]
loss: 0.668840  [  768/  801]
loss: 0.678216  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.656266 

Epoch 75
-------------------------------
loss: 0.667956  [   64/  801]
loss

loss: 0.642415  [  128/  801]
loss: 0.646822  [  192/  801]
loss: 0.651918  [  256/  801]
loss: 0.643314  [  320/  801]
loss: 0.653324  [  384/  801]
loss: 0.649863  [  448/  801]
loss: 0.644554  [  512/  801]
loss: 0.659176  [  576/  801]
loss: 0.638947  [  640/  801]
loss: 0.642665  [  704/  801]
loss: 0.665006  [  768/  801]
loss: 0.674503  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.650092 

Epoch 103
-------------------------------
loss: 0.664493  [   64/  801]
loss: 0.642199  [  128/  801]
loss: 0.646586  [  192/  801]
loss: 0.651714  [  256/  801]
loss: 0.643046  [  320/  801]
loss: 0.653151  [  384/  801]
loss: 0.649630  [  448/  801]
loss: 0.644331  [  512/  801]
loss: 0.659001  [  576/  801]
loss: 0.638686  [  640/  801]
loss: 0.642411  [  704/  801]
loss: 0.664872  [  768/  801]
loss: 0.674363  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.649877 

Epoch 104
-------------------------------
loss: 0.664368  [   64/  801]
loss: 0.641984  [  128/  801]
lo

loss: 0.672060  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.646467 

Epoch 120
-------------------------------
loss: 0.662355  [   64/  801]
loss: 0.638585  [  128/  801]
loss: 0.642618  [  192/  801]
loss: 0.648259  [  256/  801]
loss: 0.638493  [  320/  801]
loss: 0.650199  [  384/  801]
loss: 0.645716  [  448/  801]
loss: 0.640619  [  512/  801]
loss: 0.656040  [  576/  801]
loss: 0.634318  [  640/  801]
loss: 0.638112  [  704/  801]
loss: 0.662587  [  768/  801]
loss: 0.671912  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.646256 

Epoch 121
-------------------------------
loss: 0.662228  [   64/  801]
loss: 0.638375  [  128/  801]
loss: 0.642387  [  192/  801]
loss: 0.648057  [  256/  801]
loss: 0.638225  [  320/  801]
loss: 0.650024  [  384/  801]
loss: 0.645490  [  448/  801]
loss: 0.640405  [  512/  801]
loss: 0.655866  [  576/  801]
loss: 0.634065  [  640/  801]
loss: 0.637860  [  704/  801]
loss: 0.662453  [  768/  801]
loss: 0.671763  [  429/  801]
Te

loss: 0.633575  [  704/  801]
loss: 0.660178  [  768/  801]
loss: 0.669134  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.642478 

Epoch 139
-------------------------------
loss: 0.659902  [   64/  801]
loss: 0.634628  [  128/  801]
loss: 0.638276  [  192/  801]
loss: 0.644413  [  256/  801]
loss: 0.633376  [  320/  801]
loss: 0.646865  [  384/  801]
loss: 0.641472  [  448/  801]
loss: 0.636617  [  512/  801]
loss: 0.652739  [  576/  801]
loss: 0.629565  [  640/  801]
loss: 0.633322  [  704/  801]
loss: 0.660044  [  768/  801]
loss: 0.668973  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.642269 

Epoch 140
-------------------------------
loss: 0.659771  [   64/  801]
loss: 0.634422  [  128/  801]
loss: 0.638050  [  192/  801]
loss: 0.644211  [  256/  801]
loss: 0.633106  [  320/  801]
loss: 0.646689  [  384/  801]
loss: 0.641252  [  448/  801]
loss: 0.636410  [  512/  801]
loss: 0.652565  [  576/  801]
loss: 0.629319  [  640/  801]
loss: 0.633070  [  704/  801]
lo

loss: 0.633137  [  512/  801]
loss: 0.649768  [  576/  801]
loss: 0.625417  [  640/  801]
loss: 0.629015  [  704/  801]
loss: 0.657755  [  768/  801]
loss: 0.666148  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.638727 

Epoch 157
-------------------------------
loss: 0.657496  [   64/  801]
loss: 0.630945  [  128/  801]
loss: 0.634217  [  192/  801]
loss: 0.640764  [  256/  801]
loss: 0.628480  [  320/  801]
loss: 0.643666  [  384/  801]
loss: 0.637529  [  448/  801]
loss: 0.632933  [  512/  801]
loss: 0.649592  [  576/  801]
loss: 0.625174  [  640/  801]
loss: 0.628759  [  704/  801]
loss: 0.657618  [  768/  801]
loss: 0.665975  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.638519 

Epoch 158
-------------------------------
loss: 0.657358  [   64/  801]
loss: 0.630741  [  128/  801]
loss: 0.633992  [  192/  801]
loss: 0.640559  [  256/  801]
loss: 0.628204  [  320/  801]
loss: 0.643486  [  384/  801]
loss: 0.637310  [  448/  801]
loss: 0.632729  [  512/  801]
lo

Test Error: 
 Accuracy: 63.3%, Avg loss: 0.635180 

Epoch 174
-------------------------------
loss: 0.655117  [   64/  801]
loss: 0.627490  [  128/  801]
loss: 0.630405  [  192/  801]
loss: 0.637270  [  256/  801]
loss: 0.623764  [  320/  801]
loss: 0.640596  [  384/  801]
loss: 0.633825  [  448/  801]
loss: 0.629499  [  512/  801]
loss: 0.646572  [  576/  801]
loss: 0.621074  [  640/  801]
loss: 0.624380  [  704/  801]
loss: 0.655280  [  768/  801]
loss: 0.662911  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.634970 

Epoch 175
-------------------------------
loss: 0.654974  [   64/  801]
loss: 0.627288  [  128/  801]
loss: 0.630181  [  192/  801]
loss: 0.637064  [  256/  801]
loss: 0.623483  [  320/  801]
loss: 0.640413  [  384/  801]
loss: 0.633608  [  448/  801]
loss: 0.629298  [  512/  801]
loss: 0.646393  [  576/  801]
loss: 0.620834  [  640/  801]
loss: 0.624120  [  704/  801]
loss: 0.655141  [  768/  801]
loss: 0.662725  [  429/  801]
Test Error: 
 Accuracy: 63.3%, A

Test Error: 
 Accuracy: 63.3%, Avg loss: 0.631610 

Epoch 191
-------------------------------
loss: 0.652643  [   64/  801]
loss: 0.624040  [  128/  801]
loss: 0.626591  [  192/  801]
loss: 0.633732  [  256/  801]
loss: 0.618948  [  320/  801]
loss: 0.637457  [  384/  801]
loss: 0.630134  [  448/  801]
loss: 0.626104  [  512/  801]
loss: 0.643506  [  576/  801]
loss: 0.616998  [  640/  801]
loss: 0.619932  [  704/  801]
loss: 0.652895  [  768/  801]
loss: 0.659646  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.631399 

Epoch 192
-------------------------------
loss: 0.652494  [   64/  801]
loss: 0.623838  [  128/  801]
loss: 0.626366  [  192/  801]
loss: 0.633522  [  256/  801]
loss: 0.618661  [  320/  801]
loss: 0.637270  [  384/  801]
loss: 0.629917  [  448/  801]
loss: 0.625905  [  512/  801]
loss: 0.643323  [  576/  801]
loss: 0.616758  [  640/  801]
loss: 0.619667  [  704/  801]
loss: 0.652753  [  768/  801]
loss: 0.659447  [  429/  801]
Test Error: 
 Accuracy: 63.3%, A

loss: 0.615398  [  704/  801]
loss: 0.650454  [  768/  801]
loss: 0.656170  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.627793 

Epoch 209
-------------------------------
loss: 0.649905  [   64/  801]
loss: 0.620376  [  128/  801]
loss: 0.622532  [  192/  801]
loss: 0.629919  [  256/  801]
loss: 0.613717  [  320/  801]
loss: 0.634045  [  384/  801]
loss: 0.626228  [  448/  801]
loss: 0.622535  [  512/  801]
loss: 0.640188  [  576/  801]
loss: 0.612685  [  640/  801]
loss: 0.615128  [  704/  801]
loss: 0.650308  [  768/  801]
loss: 0.655959  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.627580 

Epoch 210
-------------------------------
loss: 0.649749  [   64/  801]
loss: 0.620171  [  128/  801]
loss: 0.622306  [  192/  801]
loss: 0.629704  [  256/  801]
loss: 0.613422  [  320/  801]
loss: 0.633853  [  384/  801]
loss: 0.626011  [  448/  801]
loss: 0.622337  [  512/  801]
loss: 0.640001  [  576/  801]
loss: 0.612445  [  640/  801]
loss: 0.614857  [  704/  801]
lo

loss: 0.608334  [  320/  801]
loss: 0.630530  [  384/  801]
loss: 0.622303  [  448/  801]
loss: 0.618973  [  512/  801]
loss: 0.636789  [  576/  801]
loss: 0.608351  [  640/  801]
loss: 0.610209  [  704/  801]
loss: 0.647651  [  768/  801]
loss: 0.652031  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.623697 

Epoch 228
-------------------------------
loss: 0.646871  [   64/  801]
loss: 0.616469  [  128/  801]
loss: 0.618201  [  192/  801]
loss: 0.625799  [  256/  801]
loss: 0.608030  [  320/  801]
loss: 0.630331  [  384/  801]
loss: 0.622084  [  448/  801]
loss: 0.618776  [  512/  801]
loss: 0.636598  [  576/  801]
loss: 0.608110  [  640/  801]
loss: 0.609932  [  704/  801]
loss: 0.647502  [  768/  801]
loss: 0.651806  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.623479 

Epoch 229
-------------------------------
loss: 0.646707  [   64/  801]
loss: 0.616262  [  128/  801]
loss: 0.617971  [  192/  801]
loss: 0.625580  [  256/  801]
loss: 0.607726  [  320/  801]
lo

loss: 0.626699  [  384/  801]
loss: 0.618130  [  448/  801]
loss: 0.615213  [  512/  801]
loss: 0.633112  [  576/  801]
loss: 0.603745  [  640/  801]
loss: 0.604885  [  704/  801]
loss: 0.644769  [  768/  801]
loss: 0.647628  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.619526 

Epoch 247
-------------------------------
loss: 0.643691  [   64/  801]
loss: 0.612515  [  128/  801]
loss: 0.613812  [  192/  801]
loss: 0.621587  [  256/  801]
loss: 0.602169  [  320/  801]
loss: 0.626495  [  384/  801]
loss: 0.617910  [  448/  801]
loss: 0.615015  [  512/  801]
loss: 0.632917  [  576/  801]
loss: 0.603501  [  640/  801]
loss: 0.604602  [  704/  801]
loss: 0.644616  [  768/  801]
loss: 0.647389  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.619305 

Epoch 248
-------------------------------
loss: 0.643520  [   64/  801]
loss: 0.612305  [  128/  801]
loss: 0.613579  [  192/  801]
loss: 0.621363  [  256/  801]
loss: 0.601857  [  320/  801]
loss: 0.626290  [  384/  801]
lo

loss: 0.617508  [  256/  801]
loss: 0.596458  [  320/  801]
loss: 0.622749  [  384/  801]
loss: 0.613949  [  448/  801]
loss: 0.611438  [  512/  801]
loss: 0.629341  [  576/  801]
loss: 0.599091  [  640/  801]
loss: 0.599433  [  704/  801]
loss: 0.641811  [  768/  801]
loss: 0.642963  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.615282 

Epoch 266
-------------------------------
loss: 0.640371  [   64/  801]
loss: 0.608515  [  128/  801]
loss: 0.609352  [  192/  801]
loss: 0.617278  [  256/  801]
loss: 0.596135  [  320/  801]
loss: 0.622537  [  384/  801]
loss: 0.613727  [  448/  801]
loss: 0.611238  [  512/  801]
loss: 0.629139  [  576/  801]
loss: 0.598844  [  640/  801]
loss: 0.599142  [  704/  801]
loss: 0.641652  [  768/  801]
loss: 0.642707  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.615055 

Epoch 267
-------------------------------
loss: 0.640193  [   64/  801]
loss: 0.608302  [  128/  801]
loss: 0.609114  [  192/  801]
loss: 0.617048  [  256/  801]
lo

loss: 0.605291  [  192/  801]
loss: 0.613328  [  256/  801]
loss: 0.590577  [  320/  801]
loss: 0.618893  [  384/  801]
loss: 0.609954  [  448/  801]
loss: 0.607819  [  512/  801]
loss: 0.625657  [  576/  801]
loss: 0.594627  [  640/  801]
loss: 0.594134  [  704/  801]
loss: 0.638923  [  768/  801]
loss: 0.638267  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.611165 

Epoch 284
-------------------------------
loss: 0.637097  [   64/  801]
loss: 0.604669  [  128/  801]
loss: 0.605051  [  192/  801]
loss: 0.613094  [  256/  801]
loss: 0.590246  [  320/  801]
loss: 0.618677  [  384/  801]
loss: 0.609731  [  448/  801]
loss: 0.607618  [  512/  801]
loss: 0.625450  [  576/  801]
loss: 0.594377  [  640/  801]
loss: 0.593836  [  704/  801]
loss: 0.638761  [  768/  801]
loss: 0.638001  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.610934 

Epoch 285
-------------------------------
loss: 0.636912  [   64/  801]
loss: 0.604454  [  128/  801]
loss: 0.604810  [  192/  801]
lo

loss: 0.588416  [  704/  801]
loss: 0.635826  [  768/  801]
loss: 0.633089  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.606739 

Epoch 303
-------------------------------
loss: 0.633528  [   64/  801]
loss: 0.600554  [  128/  801]
loss: 0.600450  [  192/  801]
loss: 0.608581  [  256/  801]
loss: 0.583871  [  320/  801]
loss: 0.614522  [  384/  801]
loss: 0.605473  [  448/  801]
loss: 0.603780  [  512/  801]
loss: 0.621467  [  576/  801]
loss: 0.589603  [  640/  801]
loss: 0.588110  [  704/  801]
loss: 0.635661  [  768/  801]
loss: 0.632810  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.606502 

Epoch 304
-------------------------------
loss: 0.633336  [   64/  801]
loss: 0.600335  [  128/  801]
loss: 0.600206  [  192/  801]
loss: 0.608341  [  256/  801]
loss: 0.583530  [  320/  801]
loss: 0.614301  [  384/  801]
loss: 0.605247  [  448/  801]
loss: 0.603577  [  512/  801]
loss: 0.621255  [  576/  801]
loss: 0.589349  [  640/  801]
loss: 0.587804  [  704/  801]
lo

loss: 0.578011  [  320/  801]
loss: 0.610716  [  384/  801]
loss: 0.601601  [  448/  801]
loss: 0.600318  [  512/  801]
loss: 0.617833  [  576/  801]
loss: 0.585263  [  640/  801]
loss: 0.582853  [  704/  801]
loss: 0.632833  [  768/  801]
loss: 0.627962  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.602447 

Epoch 321
-------------------------------
loss: 0.630026  [   64/  801]
loss: 0.596574  [  128/  801]
loss: 0.596012  [  192/  801]
loss: 0.604213  [  256/  801]
loss: 0.577662  [  320/  801]
loss: 0.610489  [  384/  801]
loss: 0.601372  [  448/  801]
loss: 0.600114  [  512/  801]
loss: 0.617617  [  576/  801]
loss: 0.585006  [  640/  801]
loss: 0.582540  [  704/  801]
loss: 0.632665  [  768/  801]
loss: 0.627671  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.602207 

Epoch 322
-------------------------------
loss: 0.629828  [   64/  801]
loss: 0.596351  [  128/  801]
loss: 0.595764  [  192/  801]
loss: 0.603968  [  256/  801]
loss: 0.577312  [  320/  801]
lo

loss: 0.600019  [  256/  801]
loss: 0.571665  [  320/  801]
loss: 0.606609  [  384/  801]
loss: 0.597466  [  448/  801]
loss: 0.596646  [  512/  801]
loss: 0.613921  [  576/  801]
loss: 0.580628  [  640/  801]
loss: 0.577181  [  704/  801]
loss: 0.629797  [  768/  801]
loss: 0.622626  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.598090 

Epoch 339
-------------------------------
loss: 0.626426  [   64/  801]
loss: 0.592544  [  128/  801]
loss: 0.591508  [  192/  801]
loss: 0.599770  [  256/  801]
loss: 0.571309  [  320/  801]
loss: 0.606379  [  384/  801]
loss: 0.597236  [  448/  801]
loss: 0.596441  [  512/  801]
loss: 0.613702  [  576/  801]
loss: 0.580370  [  640/  801]
loss: 0.576862  [  704/  801]
loss: 0.629628  [  768/  801]
loss: 0.622323  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.597845 

Epoch 340
-------------------------------
loss: 0.626224  [   64/  801]
loss: 0.592319  [  128/  801]
loss: 0.591257  [  192/  801]
loss: 0.599521  [  256/  801]
lo

loss: 0.610148  [  576/  801]
loss: 0.576216  [  640/  801]
loss: 0.571711  [  704/  801]
loss: 0.626892  [  768/  801]
loss: 0.617376  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.593914 

Epoch 356
-------------------------------
loss: 0.622945  [   64/  801]
loss: 0.588697  [  128/  801]
loss: 0.587201  [  192/  801]
loss: 0.595504  [  256/  801]
loss: 0.565157  [  320/  801]
loss: 0.602441  [  384/  801]
loss: 0.593270  [  448/  801]
loss: 0.592955  [  512/  801]
loss: 0.609923  [  576/  801]
loss: 0.575956  [  640/  801]
loss: 0.571386  [  704/  801]
loss: 0.626720  [  768/  801]
loss: 0.617063  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.593667 

Epoch 357
-------------------------------
loss: 0.622738  [   64/  801]
loss: 0.588469  [  128/  801]
loss: 0.586946  [  192/  801]
loss: 0.595251  [  256/  801]
loss: 0.564792  [  320/  801]
loss: 0.602208  [  384/  801]
loss: 0.593036  [  448/  801]
loss: 0.592750  [  512/  801]
loss: 0.609698  [  576/  801]
lo

loss: 0.558503  [  320/  801]
loss: 0.598206  [  384/  801]
loss: 0.589028  [  448/  801]
loss: 0.589245  [  512/  801]
loss: 0.605846  [  576/  801]
loss: 0.571238  [  640/  801]
loss: 0.565476  [  704/  801]
loss: 0.623604  [  768/  801]
loss: 0.611326  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.589182 

Epoch 375
-------------------------------
loss: 0.618960  [   64/  801]
loss: 0.584350  [  128/  801]
loss: 0.582327  [  192/  801]
loss: 0.590649  [  256/  801]
loss: 0.558129  [  320/  801]
loss: 0.597969  [  384/  801]
loss: 0.588792  [  448/  801]
loss: 0.589039  [  512/  801]
loss: 0.605618  [  576/  801]
loss: 0.570975  [  640/  801]
loss: 0.565144  [  704/  801]
loss: 0.623429  [  768/  801]
loss: 0.611003  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.588931 

Epoch 376
-------------------------------
loss: 0.618748  [   64/  801]
loss: 0.584120  [  128/  801]
loss: 0.582069  [  192/  801]
loss: 0.590391  [  256/  801]
loss: 0.557755  [  320/  801]
lo

loss: 0.614891  [   64/  801]
loss: 0.579963  [  128/  801]
loss: 0.577394  [  192/  801]
loss: 0.585723  [  256/  801]
loss: 0.550954  [  320/  801]
loss: 0.593434  [  384/  801]
loss: 0.584278  [  448/  801]
loss: 0.585107  [  512/  801]
loss: 0.601250  [  576/  801]
loss: 0.565949  [  640/  801]
loss: 0.558793  [  704/  801]
loss: 0.620111  [  768/  801]
loss: 0.604766  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.584118 

Epoch 395
-------------------------------
loss: 0.614674  [   64/  801]
loss: 0.579731  [  128/  801]
loss: 0.577133  [  192/  801]
loss: 0.585462  [  256/  801]
loss: 0.550573  [  320/  801]
loss: 0.593194  [  384/  801]
loss: 0.584040  [  448/  801]
loss: 0.584900  [  512/  801]
loss: 0.601018  [  576/  801]
loss: 0.565683  [  640/  801]
loss: 0.558456  [  704/  801]
loss: 0.619936  [  768/  801]
loss: 0.604433  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.583864 

Epoch 396
-------------------------------
loss: 0.614458  [   64/  801]
lo

loss: 0.552690  [  704/  801]
loss: 0.616964  [  768/  801]
loss: 0.598729  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.579516 

Epoch 413
-------------------------------
loss: 0.610755  [   64/  801]
loss: 0.575546  [  128/  801]
loss: 0.572420  [  192/  801]
loss: 0.580740  [  256/  801]
loss: 0.543658  [  320/  801]
loss: 0.588857  [  384/  801]
loss: 0.579737  [  448/  801]
loss: 0.581168  [  512/  801]
loss: 0.596843  [  576/  801]
loss: 0.560883  [  640/  801]
loss: 0.552349  [  704/  801]
loss: 0.616789  [  768/  801]
loss: 0.598391  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.579260 

Epoch 414
-------------------------------
loss: 0.610536  [   64/  801]
loss: 0.575313  [  128/  801]
loss: 0.572157  [  192/  801]
loss: 0.580476  [  256/  801]
loss: 0.543271  [  320/  801]
loss: 0.588616  [  384/  801]
loss: 0.579498  [  448/  801]
loss: 0.580961  [  512/  801]
loss: 0.596610  [  576/  801]
loss: 0.560616  [  640/  801]
loss: 0.552007  [  704/  801]
lo

loss: 0.592655  [  576/  801]
loss: 0.556056  [  640/  801]
loss: 0.546172  [  704/  801]
loss: 0.613644  [  768/  801]
loss: 0.592252  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.574626 

Epoch 432
-------------------------------
loss: 0.606586  [   64/  801]
loss: 0.571116  [  128/  801]
loss: 0.567409  [  192/  801]
loss: 0.575698  [  256/  801]
loss: 0.536268  [  320/  801]
loss: 0.584255  [  384/  801]
loss: 0.575174  [  448/  801]
loss: 0.577225  [  512/  801]
loss: 0.592422  [  576/  801]
loss: 0.555787  [  640/  801]
loss: 0.545827  [  704/  801]
loss: 0.613470  [  768/  801]
loss: 0.591909  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.574368 

Epoch 433
-------------------------------
loss: 0.606367  [   64/  801]
loss: 0.570883  [  128/  801]
loss: 0.567143  [  192/  801]
loss: 0.575430  [  256/  801]
loss: 0.535878  [  320/  801]
loss: 0.584013  [  384/  801]
loss: 0.574933  [  448/  801]
loss: 0.577017  [  512/  801]
loss: 0.592190  [  576/  801]
lo

loss: 0.602385  [   64/  801]
loss: 0.566675  [  128/  801]
loss: 0.562336  [  192/  801]
loss: 0.570599  [  256/  801]
loss: 0.528790  [  320/  801]
loss: 0.579643  [  384/  801]
loss: 0.570603  [  448/  801]
loss: 0.573290  [  512/  801]
loss: 0.588000  [  576/  801]
loss: 0.550690  [  640/  801]
loss: 0.539239  [  704/  801]
loss: 0.610192  [  768/  801]
loss: 0.585318  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.569455 

Epoch 452
-------------------------------
loss: 0.602163  [   64/  801]
loss: 0.566441  [  128/  801]
loss: 0.562068  [  192/  801]
loss: 0.570329  [  256/  801]
loss: 0.528393  [  320/  801]
loss: 0.579400  [  384/  801]
loss: 0.570363  [  448/  801]
loss: 0.573083  [  512/  801]
loss: 0.587767  [  576/  801]
loss: 0.550422  [  640/  801]
loss: 0.538891  [  704/  801]
loss: 0.610021  [  768/  801]
loss: 0.584968  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.569196 

Epoch 453
-------------------------------
loss: 0.601941  [   64/  801]
lo

loss: 0.565478  [  256/  801]
loss: 0.521234  [  320/  801]
loss: 0.575030  [  384/  801]
loss: 0.566037  [  448/  801]
loss: 0.569373  [  512/  801]
loss: 0.583575  [  576/  801]
loss: 0.545588  [  640/  801]
loss: 0.532598  [  704/  801]
loss: 0.606948  [  768/  801]
loss: 0.578629  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.564533 

Epoch 471
-------------------------------
loss: 0.597932  [   64/  801]
loss: 0.562001  [  128/  801]
loss: 0.556960  [  192/  801]
loss: 0.565208  [  256/  801]
loss: 0.520834  [  320/  801]
loss: 0.574787  [  384/  801]
loss: 0.565797  [  448/  801]
loss: 0.569167  [  512/  801]
loss: 0.583342  [  576/  801]
loss: 0.545319  [  640/  801]
loss: 0.532248  [  704/  801]
loss: 0.606778  [  768/  801]
loss: 0.578275  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.564275 

Epoch 472
-------------------------------
loss: 0.597709  [   64/  801]
loss: 0.561767  [  128/  801]
loss: 0.556691  [  192/  801]
loss: 0.564939  [  256/  801]
lo

loss: 0.565488  [  512/  801]
loss: 0.579182  [  576/  801]
loss: 0.540519  [  640/  801]
loss: 0.525977  [  704/  801]
loss: 0.603761  [  768/  801]
loss: 0.571926  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.559663 

Epoch 490
-------------------------------
loss: 0.593734  [   64/  801]
loss: 0.557613  [  128/  801]
loss: 0.551848  [  192/  801]
loss: 0.560098  [  256/  801]
loss: 0.513280  [  320/  801]
loss: 0.570198  [  384/  801]
loss: 0.561248  [  448/  801]
loss: 0.565285  [  512/  801]
loss: 0.578951  [  576/  801]
loss: 0.540253  [  640/  801]
loss: 0.525628  [  704/  801]
loss: 0.603595  [  768/  801]
loss: 0.571572  [  429/  801]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.559408 

Epoch 491
-------------------------------
loss: 0.593513  [   64/  801]
loss: 0.557382  [  128/  801]
loss: 0.551579  [  192/  801]
loss: 0.559829  [  256/  801]
loss: 0.512882  [  320/  801]
loss: 0.569957  [  384/  801]
loss: 0.561009  [  448/  801]
loss: 0.565081  [  512/  801]
lo

In [9]:
# torch.save(model, './model.pth')