In [1]:
# Import packages
import os
import random
from PIL import Image
from torch import nn
import torch 
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import json 
from src.config import PROCESSED_DATA_DIR, RAW_DATA_DIR
import torch.optim as optim

from loguru import logger

from utils import train_validate_model, modify_model_output, test_model

import time
import wandb

image_path = PROCESSED_DATA_DIR / "pizza_hamburger_hotdog_20_percent"
train_dir = image_path / 'train' 
test_dir = image_path / 'test'
valid_dir = image_path / 'valid'

device ='cuda' if torch.cuda.is_available() else 'cpu'

[32m2024-06-23 13:01:05.063[0m | [1mINFO    [0m | [36msrc.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: C:\Git\hamburger-hotdog-pizza-classifier[0m


In [2]:
wandb.init(project='pizza_hamburger_hotdog_20_percent')

[34m[1mwandb[0m: Currently logged in as: [33mdtiourine[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
config_path = 'params.json'

with open(config_path, 'r') as config_file:
   config = json.load(config_file)
#print(json.dumps(config, indent=2))

In [4]:
# Load hyperparams from params.json

learning_rate = config['model_params']['learning_rate']
batch_size = config['model_params']['batch_size']
num_epochs = config['model_params']['num_epochs']
dropout_rate = config['model_params']['dropout_rate']
optimizer = config['model_params']['optimizer']
loss_function = config['model_params']['loss_function']
metrics = config['model_params']['metrics']
output_shape = config['model_params']['output_shape']

In [5]:
# Prepare data into dataloader
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
])

train_data = ImageFolder(train_dir, transform=transform)
valid_data = ImageFolder(valid_dir, transform=transform)
test_data = ImageFolder(test_dir, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Start with Pre-Trained Models

Let's choose some popular CNN architectures as a starting point

First, we need to load these pretrained models

In [6]:
num_classes = 3
alexnet = modify_model_output('alexnet', num_classes, device)
vgg16 = modify_model_output('vgg16', num_classes, device)
resnet50 = modify_model_output('resnet50', num_classes, device)

## 1. Trying ResNet50

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet50.parameters(), lr=0.001, momentum=0.3)
model_name = "ResNet50"
date_time = time.time()
logger.add(f"logs/{model_name}/training_log-{date_time}.log", format="{time} {level} {message}", level="INFO")
train_validate_model(num_epochs=10, model=resnet50, train_loader=train_loader, valid_loader=valid_loader, criterion=criterion, optimizer=optimizer, device=device)

Training Epoch 1/10: 100%|██████████| 5/5 [00:01<00:00,  3.55it/s]


[32m2024-06-23 13:01:10.038[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 1, Train Loss: 1.0873, Train Accuracy: 39.26%[0m


Validation Epoch 1/10: 100%|██████████| 3/3 [00:00<00:00,  6.09it/s]


[32m2024-06-23 13:01:10.531[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 1, Validation Loss: 1.1012, Validation Accuracy: 30.56%[0m


Training Epoch 2/10: 100%|██████████| 5/5 [00:01<00:00,  3.33it/s]


[32m2024-06-23 13:01:12.049[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 2, Train Loss: 1.0838, Train Accuracy: 38.15%[0m


Validation Epoch 2/10: 100%|██████████| 3/3 [00:00<00:00,  6.46it/s]


[32m2024-06-23 13:01:12.519[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 2, Validation Loss: 1.0928, Validation Accuracy: 38.89%[0m


Training Epoch 3/10: 100%|██████████| 5/5 [00:01<00:00,  4.72it/s]


[32m2024-06-23 13:01:13.584[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 3, Train Loss: 1.0776, Train Accuracy: 43.70%[0m


Validation Epoch 3/10: 100%|██████████| 3/3 [00:00<00:00,  6.40it/s]


[32m2024-06-23 13:01:14.063[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 3, Validation Loss: 1.0805, Validation Accuracy: 44.44%[0m


Training Epoch 4/10: 100%|██████████| 5/5 [00:01<00:00,  4.69it/s]


[32m2024-06-23 13:01:15.129[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 4, Train Loss: 1.0736, Train Accuracy: 49.63%[0m


Validation Epoch 4/10: 100%|██████████| 3/3 [00:00<00:00,  5.09it/s]


[32m2024-06-23 13:01:15.735[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 4, Validation Loss: 1.0756, Validation Accuracy: 45.00%[0m


Training Epoch 5/10: 100%|██████████| 5/5 [00:00<00:00,  5.16it/s]


[32m2024-06-23 13:01:16.703[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 5, Train Loss: 1.0591, Train Accuracy: 53.70%[0m


Validation Epoch 5/10: 100%|██████████| 3/3 [00:00<00:00,  5.11it/s]


[32m2024-06-23 13:01:17.306[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 5, Validation Loss: 1.0626, Validation Accuracy: 48.89%[0m


Training Epoch 6/10: 100%|██████████| 5/5 [00:01<00:00,  4.75it/s]


[32m2024-06-23 13:01:18.359[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 6, Train Loss: 1.0426, Train Accuracy: 55.56%[0m


Validation Epoch 6/10: 100%|██████████| 3/3 [00:00<00:00,  6.47it/s]


[32m2024-06-23 13:01:18.822[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 6, Validation Loss: 1.0628, Validation Accuracy: 51.11%[0m


Training Epoch 7/10: 100%|██████████| 5/5 [00:01<00:00,  4.58it/s]


[32m2024-06-23 13:01:19.915[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 7, Train Loss: 1.0445, Train Accuracy: 57.41%[0m


Validation Epoch 7/10: 100%|██████████| 3/3 [00:00<00:00,  6.30it/s]


[32m2024-06-23 13:01:20.392[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 7, Validation Loss: 1.0500, Validation Accuracy: 51.67%[0m


Training Epoch 8/10: 100%|██████████| 5/5 [00:01<00:00,  4.32it/s]


[32m2024-06-23 13:01:21.548[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 8, Train Loss: 1.0301, Train Accuracy: 62.59%[0m


Validation Epoch 8/10: 100%|██████████| 3/3 [00:00<00:00,  6.20it/s]


[32m2024-06-23 13:01:22.031[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 8, Validation Loss: 1.0453, Validation Accuracy: 56.11%[0m


Training Epoch 9/10: 100%|██████████| 5/5 [00:00<00:00,  5.29it/s]


[32m2024-06-23 13:01:22.977[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 9, Train Loss: 1.0370, Train Accuracy: 65.93%[0m


Validation Epoch 9/10: 100%|██████████| 3/3 [00:00<00:00,  5.01it/s]


[32m2024-06-23 13:01:23.576[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 9, Validation Loss: 1.0385, Validation Accuracy: 61.67%[0m


Training Epoch 10/10: 100%|██████████| 5/5 [00:01<00:00,  4.74it/s]


[32m2024-06-23 13:01:24.632[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 10, Train Loss: 1.0305, Train Accuracy: 67.04%[0m


Validation Epoch 10/10: 100%|██████████| 3/3 [00:00<00:00,  5.03it/s]

[32m2024-06-23 13:01:25.228[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 10, Validation Loss: 1.0346, Validation Accuracy: 61.67%[0m





## 2. Trying VGG16

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(vgg16.parameters(), lr=0.001, momentum=0.3)
logger.add("training_log.log", format="{time} {level} {message}", level="INFO")
train_validate_model(num_epochs=20, model=vgg16, train_loader=train_loader, valid_loader=valid_loader, criterion=criterion, optimizer=optimizer, device=device)

Training Epoch 1/20: 100%|██████████| 5/5 [00:01<00:00,  3.66it/s]


[32m2024-06-23 13:01:26.610[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 1, Train Loss: 1.1994, Train Accuracy: 29.26%[0m


Validation Epoch 1/20: 100%|██████████| 3/3 [00:00<00:00,  5.49it/s]


[32m2024-06-23 13:01:27.166[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 1, Validation Loss: 1.0142, Validation Accuracy: 56.11%[0m


Training Epoch 2/20: 100%|██████████| 5/5 [00:01<00:00,  2.86it/s]


[32m2024-06-23 13:01:28.914[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 2, Train Loss: 0.9591, Train Accuracy: 49.63%[0m


Validation Epoch 2/20: 100%|██████████| 3/3 [00:00<00:00,  4.95it/s]


[32m2024-06-23 13:01:29.521[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 2, Validation Loss: 0.8956, Validation Accuracy: 71.67%[0m


Training Epoch 3/20: 100%|██████████| 5/5 [00:01<00:00,  2.98it/s]


[32m2024-06-23 13:01:31.201[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 3, Train Loss: 0.8845, Train Accuracy: 59.26%[0m


Validation Epoch 3/20: 100%|██████████| 3/3 [00:00<00:00,  5.33it/s]


[32m2024-06-23 13:01:31.781[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 3, Validation Loss: 0.7897, Validation Accuracy: 77.78%[0m


Training Epoch 4/20: 100%|██████████| 5/5 [00:01<00:00,  3.68it/s]


[32m2024-06-23 13:01:33.149[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 4, Train Loss: 0.7950, Train Accuracy: 65.93%[0m


Validation Epoch 4/20: 100%|██████████| 3/3 [00:00<00:00,  3.12it/s]


[32m2024-06-23 13:01:34.112[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 4, Validation Loss: 0.7085, Validation Accuracy: 79.44%[0m


Training Epoch 5/20: 100%|██████████| 5/5 [00:01<00:00,  3.78it/s]


[32m2024-06-23 13:01:35.437[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 5, Train Loss: 0.6996, Train Accuracy: 77.41%[0m


Validation Epoch 5/20: 100%|██████████| 3/3 [00:01<00:00,  3.00it/s]


[32m2024-06-23 13:01:36.453[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 5, Validation Loss: 0.6059, Validation Accuracy: 84.44%[0m


Training Epoch 6/20: 100%|██████████| 5/5 [00:01<00:00,  3.76it/s]


[32m2024-06-23 13:01:37.798[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 6, Train Loss: 0.5976, Train Accuracy: 80.37%[0m


Validation Epoch 6/20: 100%|██████████| 3/3 [00:00<00:00,  5.14it/s]


[32m2024-06-23 13:01:38.383[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 6, Validation Loss: 0.5509, Validation Accuracy: 83.89%[0m


Training Epoch 7/20: 100%|██████████| 5/5 [00:01<00:00,  2.93it/s]


[32m2024-06-23 13:01:40.087[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 7, Train Loss: 0.5455, Train Accuracy: 81.48%[0m


Validation Epoch 7/20: 100%|██████████| 3/3 [00:00<00:00,  4.92it/s]


[32m2024-06-23 13:01:40.704[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 7, Validation Loss: 0.4928, Validation Accuracy: 85.00%[0m


Training Epoch 8/20: 100%|██████████| 5/5 [00:01<00:00,  2.85it/s]


[32m2024-06-23 13:01:42.455[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 8, Train Loss: 0.4698, Train Accuracy: 84.07%[0m


Validation Epoch 8/20: 100%|██████████| 3/3 [00:00<00:00,  5.36it/s]


[32m2024-06-23 13:01:43.015[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 8, Validation Loss: 0.4452, Validation Accuracy: 86.67%[0m


Training Epoch 9/20: 100%|██████████| 5/5 [00:01<00:00,  3.82it/s]


[32m2024-06-23 13:01:44.325[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 9, Train Loss: 0.4488, Train Accuracy: 84.81%[0m


Validation Epoch 9/20: 100%|██████████| 3/3 [00:00<00:00,  3.03it/s]


[32m2024-06-23 13:01:45.315[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 9, Validation Loss: 0.4193, Validation Accuracy: 85.56%[0m


Training Epoch 10/20: 100%|██████████| 5/5 [00:01<00:00,  3.70it/s]


[32m2024-06-23 13:01:46.678[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 10, Train Loss: 0.3736, Train Accuracy: 87.78%[0m


Validation Epoch 10/20: 100%|██████████| 3/3 [00:00<00:00,  3.14it/s]


[32m2024-06-23 13:01:47.649[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 10, Validation Loss: 0.3746, Validation Accuracy: 87.22%[0m


Training Epoch 11/20: 100%|██████████| 5/5 [00:01<00:00,  3.84it/s]


[32m2024-06-23 13:01:48.954[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 11, Train Loss: 0.3636, Train Accuracy: 86.67%[0m


Validation Epoch 11/20: 100%|██████████| 3/3 [00:00<00:00,  5.17it/s]


[32m2024-06-23 13:01:49.534[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 11, Validation Loss: 0.3565, Validation Accuracy: 87.22%[0m


Training Epoch 12/20: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s]


[32m2024-06-23 13:01:51.272[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 12, Train Loss: 0.2982, Train Accuracy: 90.37%[0m


Validation Epoch 12/20: 100%|██████████| 3/3 [00:00<00:00,  5.06it/s]


[32m2024-06-23 13:01:51.875[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 12, Validation Loss: 0.3501, Validation Accuracy: 88.89%[0m


Training Epoch 13/20: 100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


[32m2024-06-23 13:01:53.594[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 13, Train Loss: 0.3279, Train Accuracy: 87.04%[0m


Validation Epoch 13/20: 100%|██████████| 3/3 [00:00<00:00,  5.53it/s]


[32m2024-06-23 13:01:54.136[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 13, Validation Loss: 0.3484, Validation Accuracy: 88.33%[0m


Training Epoch 14/20: 100%|██████████| 5/5 [00:01<00:00,  3.81it/s]


[32m2024-06-23 13:01:55.469[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 14, Train Loss: 0.2672, Train Accuracy: 89.63%[0m


Validation Epoch 14/20: 100%|██████████| 3/3 [00:00<00:00,  3.18it/s]


[32m2024-06-23 13:01:56.415[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 14, Validation Loss: 0.3257, Validation Accuracy: 90.00%[0m


Training Epoch 15/20: 100%|██████████| 5/5 [00:01<00:00,  3.79it/s]


[32m2024-06-23 13:01:57.735[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 15, Train Loss: 0.2317, Train Accuracy: 91.48%[0m


Validation Epoch 15/20: 100%|██████████| 3/3 [00:00<00:00,  3.04it/s]


[32m2024-06-23 13:01:58.723[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 15, Validation Loss: 0.3179, Validation Accuracy: 89.44%[0m


Training Epoch 16/20: 100%|██████████| 5/5 [00:01<00:00,  3.75it/s]


[32m2024-06-23 13:02:00.066[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 16, Train Loss: 0.2506, Train Accuracy: 89.26%[0m


Validation Epoch 16/20: 100%|██████████| 3/3 [00:00<00:00,  5.43it/s]


[32m2024-06-23 13:02:00.620[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 16, Validation Loss: 0.3040, Validation Accuracy: 88.33%[0m


Training Epoch 17/20: 100%|██████████| 5/5 [00:01<00:00,  2.97it/s]


[32m2024-06-23 13:02:02.304[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 17, Train Loss: 0.1940, Train Accuracy: 93.33%[0m


Validation Epoch 17/20: 100%|██████████| 3/3 [00:00<00:00,  5.32it/s]


[32m2024-06-23 13:02:02.868[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 17, Validation Loss: 0.2925, Validation Accuracy: 90.56%[0m


Training Epoch 18/20: 100%|██████████| 5/5 [00:01<00:00,  2.92it/s]


[32m2024-06-23 13:02:04.578[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 18, Train Loss: 0.2179, Train Accuracy: 93.70%[0m


Validation Epoch 18/20: 100%|██████████| 3/3 [00:00<00:00,  5.62it/s]


[32m2024-06-23 13:02:05.111[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 18, Validation Loss: 0.2922, Validation Accuracy: 88.89%[0m


Training Epoch 19/20: 100%|██████████| 5/5 [00:01<00:00,  3.84it/s]


[32m2024-06-23 13:02:06.414[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 19, Train Loss: 0.1972, Train Accuracy: 92.59%[0m


Validation Epoch 19/20: 100%|██████████| 3/3 [00:01<00:00,  2.88it/s]


[32m2024-06-23 13:02:07.471[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 19, Validation Loss: 0.2908, Validation Accuracy: 88.33%[0m


Training Epoch 20/20: 100%|██████████| 5/5 [00:01<00:00,  3.54it/s]


[32m2024-06-23 13:02:08.901[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m88[0m - [1mEpoch 20, Train Loss: 0.1543, Train Accuracy: 95.56%[0m


Validation Epoch 20/20: 100%|██████████| 3/3 [00:01<00:00,  2.69it/s]

[32m2024-06-23 13:02:10.014[0m | [1mINFO    [0m | [36mutils[0m:[36mtrain_validate_model[0m:[36m111[0m - [1mEpoch 20, Validation Loss: 0.2889, Validation Accuracy: 89.44%[0m





## 3. Trying AlexNet

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(alexnet.parameters(), lr=0.001, momentum=0.3)
train_validate_model(num_epochs=10, model=alexnet, train_loader=train_loader, valid_loader=valid_loader, criterion=criterion, optimizer=optimizer, device=device)

AttributeError: 'int' object has no attribute '_name'

In [9]:
import torchvision.models as models
vgg11 = models.vgg11(weights='DEFAULT')
num_features = vgg11.classifier[6].in_features
vgg11.classifier[6] = nn.Linear(num_features, 3)

# Training/Validation Loop
train_validate_model(num_epochs=10, train_loader=train_loader, valid_loader=valid_loader, model=model, criterion=criterion, optimizer=optimizer, device=device)

NameError: name 'model' is not defined

In [25]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.3)

train_validate_model(num_epochs=10, train_loader=train_loader, valid_loader=valid_loader, model=model, criterion=criterion, optimizer=optimizer, device=device)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

In [None]:
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import vgg16, VGG16_Weights
from torchvision.models import alexnet, AlexNet_Weights
from torchvision.models import incep

In [24]:
#model = resnet50(weights=ResNet50_Weights.DEFAULT).to(device)
backbone = model.features
#backbone[-1].out_channels
backbone

AttributeError: 'ResNet' object has no attribute 'features'

In [28]:
num_classes = 3
alexnet = modify_model_output('alexnet', num_classes)
vgg16 = modify_model_output('vgg16', num_classes)
resnet50 = modify_model_output('resnet50', num_classes)

Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to C:\Users\MLDev/.cache\torch\hub\checkpoints\alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:07<00:00, 32.5MB/s] 
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to C:\Users\MLDev/.cache\torch\hub\checkpoints\vgg16-397923af.pth
100%|██████████| 528M/528M [00:22<00:00, 25.0MB/s] 
Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to C:\Users\MLDev/.cache\torch\hub\checkpoints\inception_v3_google-0cc3c7bd.pth
100%|██████████| 104M/104M [00:03<00:00, 33.5MB/s] 

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 




In [None]:
from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights=ResNet50_Weights.DEFAULT).to(device)
train_validate_model(num_epochs=100, train_loader=train_loader, valid_loader=valid_loader, model=model.softmax(3), criterion=criterion, optimizer=optimizer, device=device)
#num_features = resnet50.classifier[6].in_features
#vgg11.classifier[6] = nn.Linear(num_features, 3)

In [26]:
model = models.vgg16(weights='DEFAULT')

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to C:\Users\MLDev/.cache\torch\hub\checkpoints\vgg16-397923af.pth
 27%|██▋       | 144M/528M [00:05<00:14, 27.7MB/s] 


KeyboardInterrupt: 