In [1]:
#|export
from fastai.vision.all import *
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [11]:
#|export
#path = untar_data(URLs.IMAGENETTE) # this downloads to ~/.fastai/data/imagenette2
models = {
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
}
def train(model_name, epochs=1):
    if model_name not in models:
        raise ValueError(f"Model name '{model_name}' is not supported. Choose from: {list(models.keys())}")
 
    model = models[model_name]

    path = Path('data/imagenette2')
    # ensure that transformed data is not used as well
    train_fnames = get_image_files(path/'train')
    val_fnames = get_image_files(path/'val')
    fnames = train_fnames + val_fnames

    # load the data
    dls = ImageDataLoaders.from_path_func(
        path, fnames, label_func=parent_label, item_tfms=Resize(224))
    
    # setup model with pretrained weights and data
    learn = vision_learner(dls, model, metrics=CrossEntropyLossFlat(), 
                           pretrained=True)

    # train with model frozen except for the last layer
    learn.fine_tune(epochs)

    # save the model to models folder
    models_dir = Path('models')
    # create the models directory if it doesn't exist
    models_dir.mkdir(parents=True, exist_ok=True)
    save_path = models_dir / f'{model_name}.pt'
    torch.save(learn.model, save_path)

In [None]:
train('resnet34', epochs=3)

epoch,train_loss,valid_loss,accuracy,time
0,0.174275,0.025039,0.991785,02:21


epoch,train_loss,valid_loss,accuracy,time
0,0.104469,0.045552,0.985437,03:01


In [18]:
train('resnet50', epochs=3)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /Users/juhokokko/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:02<00:00, 35.9MB/s]


epoch,train_loss,valid_loss,accuracy,time
0,0.16385,0.029039,0.992158,03:46


epoch,train_loss,valid_loss,accuracy,time
0,0.057236,0.034291,0.991038,04:30


In [12]:
#|export
def load(model_name):
    model = torch.load(f'models/{model_name}.pt', map_location=device)
    model.eval()
    return model

In [13]:
input_data = torch.randn(1, 3, 224, 224)

In [18]:
# test that the model is loaded correctly
model = load('resnet18')
model.eval()
with torch.no_grad(): 
    output = model(input_data)
    print(output)

assert output.shape == (1, 10), f"Expected output shape to be (1, 10), but got {output.shape}"

tensor([[-2.1822,  2.4006,  0.1702, -3.8383,  3.1280, -1.4197, -2.2073, -2.5701,
          5.8010,  1.5930]])


In [19]:
# test that the model is loaded correctly
model = load('resnet34')
with torch.no_grad(): 
    output = model(input_data)
    print(output)

assert output.shape == (1, 10), f"Expected output shape to be (1, 10), but got {output.shape}"

tensor([[-3.0101,  1.5921,  1.6258, -1.2849,  1.5776,  0.2059, -2.4348, -3.6694,
          4.9993,  0.8250]])


In [67]:
# test that the model is loaded correctly
model = load('resnet50')
with torch.no_grad(): 
    output = model(input_data)
    print(output)

assert output.shape == (1, 10), f"Expected output shape to be (1, 10), but got {output.shape}"

tensor([[-1.1353, -0.2133,  2.9741,  1.9240,  2.2523,  1.4309, -2.5617,  0.1402,
          1.6941,  2.3228]])


In [None]:
#current_path = Path('.')
#learn.path = current_path
#learn.model_dir = 'models'
#learn.save('resnet18')

In [26]:
model = load('resnet18')
model

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  