In [1]:
import os
#from helpers import *

from jlib.cifar_preprocessing import get_cifar_loaders, delete_deletables
from jlib.resnet import ResNet, ConvParams
import torch
from torch import nn
import matplotlib.pyplot as plt
# Target device: every model built in this notebook is moved here via `.to(device)`.
device = "cuda"

# The bare string below is a no-op expression kept purely as a reference note.
# `fuser -k` kills the processes that hold the listed files open (here a WSL
# NVIDIA driver directory) -- presumably used to free a stuck GPU; confirm
# before running it.
"""
sudo fuser -v -k /usr/lib/wsl/drivers/nvhm.inf_amd64_5c197d2d97068bef/*
"""
    




def train_and_plot(model, train, val, title, min_accuracy=0.65):
    """Train ``model``, then save the model and its training plot to disk.

    Prints a short header (title and parameter count), runs
    ``model.train_model`` with the fixed hyper-parameters below, writes the
    whole model object to ``models/<title>.pth`` and the training figure to
    ``figures/<title>.png``.

    Args:
        model: Object exposing ``parameters()``, ``train_model(...)`` and
            ``plot_training(...)`` (the ``ResNet`` instances in this notebook).
        train: Loader forwarded to ``train_model`` as ``train_loader``.
        val: Loader forwarded to ``train_model`` as ``val_loader``.
        title: Run label; printed, used as the plot title and as the stem of
            both output file names.
        min_accuracy: Currently unused -- ``train_model`` is always called
            with ``min_accuracy=1``. NOTE(review): confirm whether this
            argument was meant to be forwarded; wiring it through would
            change training behaviour for the default value of 0.65.

    Returns:
        None.
    """
    print(f"Model: {title}")
    print("-"*50)
    print(f"Num Params: {sum(p.numel() for p in model.parameters())}")

    # Neither torch.save nor plt.savefig creates missing parent directories.
    # Create them up front so a missing folder is reported before training
    # starts rather than after all epochs have run.
    os.makedirs('models', exist_ok=True)
    os.makedirs('figures', exist_ok=True)

    model.train_model(
        epochs=50,
        train_loader=train,
        val_loader=val,
        loss_fn=nn.CrossEntropyLoss(),
        # The optimizer class (not an instance) is passed together with its
        # constructor arguments; train_model is expected to instantiate it.
        optimizer=torch.optim.Adam,
        optimizer_args=[],
        optimizer_kwargs={'lr': 1e-4},
        print_epoch=1,
        header_epoch=10,
        sched_factor=0.1,
        sched_patience=5,
        # Hard-coded; the function's own min_accuracy parameter is ignored.
        min_accuracy = 1,
        max_negative_diff_count = 7
    )
    torch.save(model, f'models/{title}.pth')
    model.plot_training(f"Training {title}")
    plt.savefig(f'figures/{title}.png')
    
# Keyword arguments shared by every ResNet built in this notebook; the call
# sites add `num_classes` and `dropout` and splat this dict with `**`.
architecture = {
    'in_chan': 3,
    'in_dim': (32, 32),
    # Stem convolution: 3x3, 64 output channels, stride 1, 'same' padding.
    'first_conv': ConvParams(kernel=3, out_chan=64, stride=1, padding='same'),
    # One ConvParams per entry; only the output channel count differs.
    'block_params': [
        ConvParams(kernel=3, out_chan=width, stride=1, padding='same')
        for width in (64, 128, 256, 512)
    ],
    'fc_params': [2048, 2048],
}

# Loaders for the 100-class runs. `is_cifar10` must be False here: with True,
# the models built with num_classes=100 and saved under "CIFAR100" titles
# would be trained on the CIFAR-10 selection instead.
# NOTE(review): flag semantics inferred from its name -- confirm in
# jlib.cifar_preprocessing.get_cifar_loaders.
cifar_100_train_loader, cifar_100_val_loader = get_cifar_loaders(
    is_cifar10=False,
    train_batch_size=512,
    val_batch_size=2048,
    train_workers=6,
    train_cpu_prefetch=10,
    train_gpu_prefetch=10,
    val_workers=2,
    val_cpu_prefetch=2,
    val_gpu_prefetch=2,
)


# Two 100-class runs on the `cifar_100_*` loaders, identical except for dropout.
# Each run: clear the CUDA cache, build the model on `device`, train/plot/save,
# then hand the model to `delete_deletables` (project helper).

# Run 1: dropout = 0.5
torch.cuda.empty_cache()
Res_100_dp = ResNet(num_classes=100, dropout=0.5, **architecture).to(device)
train_and_plot(
    Res_100_dp,
    cifar_100_train_loader,
    cifar_100_val_loader,
    "ResNet CIFAR100 DO",
)
delete_deletables([Res_100_dp])

# Run 2: dropout = 0
torch.cuda.empty_cache()
Res_100_ndp = ResNet(num_classes=100, dropout=0, **architecture).to(device)
train_and_plot(
    Res_100_ndp,
    cifar_100_train_loader,
    cifar_100_val_loader,
    "ResNet CIFAR100 NO DO",
)
delete_deletables([Res_100_ndp])

# Both runs are finished; pass this loader pair to the cleanup helper as well.
delete_deletables([cifar_100_train_loader, cifar_100_val_loader])

# Loader pair for the 10-class runs (validation batch size is 512 here).
cifar_10_train_loader, cifar_10_val_loader = get_cifar_loaders(
    is_cifar10=True,
    train_batch_size=512, val_batch_size=512,
    train_workers=6, train_cpu_prefetch=10, train_gpu_prefetch=10,
    val_workers=2, val_cpu_prefetch=2, val_gpu_prefetch=2,
)

# Run 1: dropout = 0.5
torch.cuda.empty_cache()
Res_10_dp = ResNet(num_classes=10, dropout=0.5, **architecture).to(device)
train_and_plot(
    Res_10_dp,
    cifar_10_train_loader,
    cifar_10_val_loader,
    "ResNet CIFAR10 DO",
)
delete_deletables([Res_10_dp])

# Run 2: dropout = 0
torch.cuda.empty_cache()
Res_10_ndp = ResNet(num_classes=10, dropout=0, **architecture).to(device)
train_and_plot(
    Res_10_ndp,
    cifar_10_train_loader,
    cifar_10_val_loader,
    "ResNet CIFAR10 NO DO",
)
# The last model is dropped with a plain `del` rather than delete_deletables.
del Res_10_ndp


  from .autonotebook import tqdm as notebook_tqdm


Begin init data loader
Batch Size: 6.0 MiB
Data Size: 292.96875 GiB
Data Loader init time: 3.487697 s
Begin init fetcher
Fetcher init time: 3.568801 s
Begin init data loader
Batch Size: 24.0 MiB
Data Size: 234.375 GiB
Data Loader init time: 1.550779 s
Begin init fetcher
Fetcher init time: 1.597170 s


TypeError: jlib.resnet.ResNet() got multiple values for keyword argument 'dropout'

: 