In [1]:
# Standard Libraries
import os
import json
import math
import numpy as np
import copy

# Imports for plotting
import matplotlib.pyplot as plt
# NOTE(review): `_cm` is a private matplotlib module (leading underscore) and is not
# referenced anywhere in the visible part of this notebook — likely a leftover; consider removing it.
from matplotlib import _cm
# Render figures inline, directly below the cell that produces them
%matplotlib inline
# Ask IPython to produce SVG and PDF representations of inline figures.
# NOTE(review): `IPython.display.set_matplotlib_formats` is deprecated in newer IPython
# releases (the echoed `set_matplotlib_formats(...)` line in this cell's output is that
# deprecation warning); `matplotlib_inline.backend_inline.set_matplotlib_formats` is the replacement.
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf')
import seaborn as sns
# Apply seaborn's default theme to all subsequent matplotlib figures
sns.set()

# Progress bar
from tqdm.notebook import tqdm

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

  set_matplotlib_formats('svg', 'pdf')


In [8]:
# Paths used throughout this notebook (relative to the notebook's working directory)

# Path to the dataset folder
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved (and downloaded to below)
CHECKPOINT_PATH = "../saved_models/tutorial4"

In [3]:
# Function for setting the seed (same purpose as in Tutorial 3)
def set_seed(seed: int) -> None:
    """Seed every random number generator used in this notebook.

    Seeds Python's built-in ``random`` module, NumPy's global RNG and PyTorch
    (CPU, plus all CUDA devices when a GPU is available) so that repeated runs
    produce the same random numbers.

    Args:
        seed: Integer seed passed to each generator.
    """
    # Local import: the notebook's import cell does not import the stdlib `random` module.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        # Explicit CUDA seeding; redundant with torch.manual_seed in recent PyTorch
        # versions, but harmless and explicit.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

set_seed(42)

In [4]:
# cuDNN reproducibility settings (only relevant when running on a GPU):
#   - deterministic=True forces cuDNN to use deterministic algorithms,
#   - benchmark=False disables cuDNN's autotuner, whose algorithm choice may differ between runs.
# Note: this covers cuDNN only; it does not make every PyTorch operation deterministic.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

In [7]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device", device)

Using device cpu


In [9]:
# Downloading pretrained models

import urllib.request
from urllib.error import HTTPError, URLError

# Github URL for saved models
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/"

# Files to download: for each optimizer (SGD, SGD with momentum, Adam) a .config file,
# a results .json file and a .tar file (presumably the saved model checkpoint).
pretrained_files = ["FashionMNIST_SGD.config",    "FashionMNIST_SGD_results.json",    "FashionMNIST_SGD.tar", 
                    "FashionMNIST_SGDMom.config", "FashionMNIST_SGDMom_results.json", "FashionMNIST_SGDMom.tar", 
                    "FashionMNIST_Adam.config",   "FashionMNIST_Adam_results.json",   "FashionMNIST_Adam.tar"   ]

# Create the checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists.
# If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if os.path.isfile(file_path):
        continue
    file_url = base_url + file_name
    print(f"Downloading {file_url}...")
    # Download into a temporary ".part" file and only move it to its final name once the
    # transfer has completed. Otherwise an interrupted download would leave a truncated
    # file behind that the isfile() check above would silently accept on every later run.
    tmp_path = file_path + ".part"
    try:
        urllib.request.urlretrieve(file_url, tmp_path)
        os.replace(tmp_path, file_path)
    except URLError as e:
        # URLError is the base class of HTTPError and ContentTooShortError and is also raised
        # for connection failures, so one failed file is reported without aborting the loop.
        print("Something went wrong: ", e)
    finally:
        # Remove any leftover partial download (no-op after a successful os.replace).
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_SGD.config...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_SGD_results.json...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_SGD.tar...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_SGDMom.config...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_SGDMom_results.json...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_SGDMom.tar...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_Adam.config...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_Adam_results.json...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_Adam.tar..