In [None]:
import os
import random
import numpy as np
import torch
import torchvision
from torchsummary import summary

from project_18408.datasets import *
from project_18408.training import *
from project_18408.evaluation import *
from project_18408.models.relu_toy_models import *
from project_18408.utils import *

In [None]:
print("PyTorch Version:", torch.__version__)
print("Torchvision Version:", torchvision.__version__)

# Seed every RNG this notebook can draw from (stdlib `random`, NumPy, PyTorch) so
# that weight initialisation and any data shuffling drawing from these generators
# repeat on Restart & Run All. `torch.manual_seed` also seeds the CUDA generators.
# NOTE(review): some GPU kernels are still nondeterministic; see PyTorch's
# reproducibility notes if bit-exact GPU runs are required.
random_seed = 0
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Detect if we have a GPU available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print("Using the GPU!")
else:
    print("WARNING: Could not find GPU! Using CPU only")

In [None]:
# Package-relative locations for the raw data, saved weights, and session files.
data_dir, weights_dir, session_dir = (
    get_rel_pkg_path(subdir) for subdir in ("dataset/", "weights/", "sessions/")
)

In [None]:
# Which image dataset to use; the input-size and class-count lookups further down
# are keyed by this value.
dataset_type = ImageDatasetType.MNIST

In [None]:
# Load the datasets before any transforms are applied (transforms come later);
# the result is indexed by split name ('train' and 'test' are used below).
orig_datasets = get_img_dataset(data_dir, dataset_type)

In [None]:
# Preview a few grayscale samples from each split, each preceded by its label.
for split_label, split_key in (("Training", 'train'), ("Testing", 'test')):
    print(split_label)
    show_dataset_samples_img(orig_datasets[split_key], cmap='gray')

In [None]:
# Apply the dataset-specific transforms. flatten=True presumably turns each image
# into a flat vector, matching the MLP's flat input size (see
# IMG_DATASET_TO_IMG_SIZE_FLAT below) — confirm in apply_img_transforms.
datasets = apply_img_transforms(orig_datasets, dataset_type, flatten=True)

In [None]:
# NOTE(review): the two positional 128s are presumably per-split batch sizes —
# confirm against get_dataloaders' signature and consider naming them.
# num_workers=0 is presumably forwarded to DataLoader (load in the main process).
dataloaders = get_dataloaders(datasets, 128, 128, num_workers=0)

In [None]:
# Input width and output width are looked up from the chosen dataset type.
input_dim = IMG_DATASET_TO_IMG_SIZE_FLAT[dataset_type]
num_classes = IMG_DATASET_TO_NUM_CLASSES[dataset_type]

# layer_dims lists eight entries of width 100 (presumably the hidden layers).
# The result of `.to(device)` is what gets bound to `model`, same as reassigning.
model = ReLUToyModel(input_dim, num_classes, layer_dims=[100] * 8).to(device)

In [None]:
# Layer-by-layer summary for a single flat input vector of length input_dim.
summary(model, input_size=(input_dim,))

In [None]:
# Loss function, moved to the same device as the model.
criterion = get_loss().to(device)

# NOTE(review): the keyword is spelled `optimzer_type` (sic). It must match
# make_optimizer's parameter name — confirm its signature before renaming.
optimizer = make_optimizer(
    model,
    optimzer_type=OptimizerType.SGD_MOMENTUM,
    lr=0.001,
    weight_decay=1e-5,
    verbose=False,
)

num_epochs = 60

In [None]:
tracker = train_model(
    # What to train, where, on which data, and for how long.
    model=model,
    device=device,
    dataloaders=dataloaders,
    num_epochs=num_epochs,
    # Optimisation setup (no learning-rate schedule).
    criterion=criterion,
    optimizer=optimizer,
    lr_scheduler=None,
    # Persistence flags; exact semantics live in train_model — presumably a log
    # plus only the latest checkpoint are written under weights_dir.
    save_dir=weights_dir,
    save_log=True,
    save_model=True,
    save_latest=True,
    save_best=False,
    save_all=False,
)

In [None]:
# Show the directory the training tracker reports for this run (presumably where
# the log/checkpoint were written — confirm in train_model).
print(tracker.save_dir)

In [None]:
# Print the shape of element 0 (presumably the inputs) of the first test batch.
# `next(iter(...))` fetches exactly one batch without leaking a loop variable into
# the notebook namespace, and raises StopIteration on an empty loader instead of
# silently printing nothing.
print(next(iter(dataloaders['test']))[0].shape)

In [None]:
# NOTE(review): interactive-only IPython magic — opens a Qt console attached to this
# kernel (requires the qtconsole package). It contributes nothing to the results;
# consider removing it so Restart & Run All works in headless environments.
%qtconsole

In [None]:
# Display the model object (its repr) as the cell output.
model

In [None]:
# Fetch the first test batch for interactive inspection. The name `n` is kept
# because the following cell reads it. An empty loader raises StopIteration here,
# rather than leaving `n` undefined or silently stale from an earlier cell.
n = next(iter(dataloaders['test']))

In [None]:
# Shape of element 0 (presumably the inputs) of the batch bound to `n` above.
n[0].shape