In [None]:
# add src to the notebook path
import sys
sys.path.append("../src/")

# import from src
from utils import *
from models.model import *
from data.dataset import *
from constants import *

# import python libraries
import os
import PIL
import glob
import wandb
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchinfo import summary


# Weights & Biases credentials and notebook metadata.
# Prefer an API key already exported in the environment; fall back to
# `wandb_key` only when none is set. `wandb_key` is not defined in this
# notebook — presumably it comes from one of the star imports above (TODO confirm).
# NOTE(review): avoid keeping real API keys in source files; export
# WANDB_API_KEY in the shell (or use getpass) instead.
if not os.environ.get('WANDB_API_KEY'):
    os.environ['WANDB_API_KEY'] = wandb_key
# NOTE(review): wandb expects the notebook's name/path in this variable;
# os.getcwd() is the working directory, not the notebook file — confirm intent.
os.environ['WANDB_NOTEBOOK_NAME'] = os.getcwd()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

# Training

In [None]:
from pprint import pprint

# Path to the experiment configuration, relative to this notebook.
CONFIG_PATH = "../config/config.yaml"

config = load_yaml_as_dict(CONFIG_PATH)

# Echo the loaded settings so the run's configuration is visible in the output.
pprint(config)

In [None]:
# Directory of images read by the dataset.
# NOTE(review): despite its name, this path feeds the *training* set below.
testing_images = config["data_path"]

# Every image is resized to a square of side config["image_size"] and then
# converted to a tensor. No data augmentation is applied.
side = config["image_size"]
image_size = (side, side)
transform = T.Compose([T.Resize(image_size), T.ToTensor()])

train_set = MP3DL_Dataset(testing_images, transform, load_all=True)

# Batch size and shuffling are both taken from the "training" config section.
training_cfg = config["training"]
train_loader = DataLoader(
    train_set,
    batch_size=training_cfg["batch_size"],
    shuffle=training_cfg["shuffle"],
)

In [None]:
# Instantiate the network and move it to the target device immediately, so
# any optimizer constructed afterwards is built over parameters that already
# live on `device` (the ordering recommended by the torch.optim docs).
model = MyModel().to(device)

# Print a torchinfo summary of the model.
summary(model)

In [None]:
training_cfg = config["training"]

# Build the loss from its configured name (see `select_loss_function`).
loss_fn = select_loss_function(training_cfg["loss_fn"])

# Optimizer from the torch.optim package, selected by name from the config.
# Momentum defaults to 0.9 and can be overridden via
# config["training"]["momentum"] instead of being hard-coded.
optimizer = select_optimizer(
    model,
    training_cfg["optim"],
    learning_rate=training_cfg["lr"],
    momentum=training_cfg.get("momentum", 0.9),
)

# No learning-rate scheduler is used for now.
# TODO: re-enable with
#   select_scheduler(training_cfg["scheduler"], optimizer, last_epoch=-1)
scheduler = None

In [None]:
# Start a new W&B run that records the config and the per-epoch training loss.
# mode: [online, disabled]
wandb.init(project="my-awesome-project", config=config, mode="online", settings=wandb.Settings(disable_job_creation=True))

epochs = config["training"]["epochs"]

# Ensure the model lives on the selected device before training.
model = model.to(device)

# NOTE(review): `tqdm` is not imported explicitly in this notebook; presumably
# it comes from one of the star imports in the first cell — TODO confirm.
try:
    for epoch in tqdm(range(epochs)):
        loss = train_one_epoch(model, train_loader, loss_fn, optimizer, scheduler)
        # TODO: add validation, e.g. acc = evaluate(model, validation_loader)
        wandb.log({"epoch": epoch, "loss": loss})
except BaseException:
    # Close the run even when training fails or is interrupted (Ctrl-C),
    # marking it as failed rather than leaving it open, then re-raise.
    wandb.finish(exit_code=1)
    raise

# Finishing the run is necessary in notebooks, since the process does not exit.
wandb.finish()