In [2]:
#|export
from fastai.vision.all import *
import torch
# Device that load() moves reloaded models onto: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [15]:
#|export
# train() reads the dataset from data/imagenette2. To download it instead:
#   path = untar_data(URLs.IMAGENETTE)  # this downloads to ~/.fastai/data/imagenette2
# Supported architectures, keyed by the name train() accepts.
models = dict(
    resnet18=resnet18,
    resnet34=resnet34,
    resnet50=resnet50,
)
def train(model_name, epochs=1, data_path='data/imagenette2', models_dir='models', seed=None):
    """Fine-tune a pretrained ResNet on Imagenette and save the whole model.

    Args:
        model_name: key of ``models`` -- 'resnet18', 'resnet34' or 'resnet50'.
        epochs: number of unfrozen epochs for ``Learner.fine_tune``. fine_tune first
            runs one extra epoch with the pretrained body frozen, so each call
            trains for ``epochs + 1`` epochs in total.
        data_path: directory containing the ``train`` and ``val`` image folders.
        models_dir: directory the model is written to as ``<model_name>.pt``;
            created if it does not exist.
        seed: seed for the random train/validation split. ``None`` (the default)
            gives a different split on every call.

    Raises:
        ValueError: if ``model_name`` is not one of the supported architectures.
    """
    if model_name not in models:
        raise ValueError(f"Model name '{model_name}' is not supported. Choose from: {list(models.keys())}")

    # An architecture factory (e.g. resnet18), not an instantiated model.
    arch = models[model_name]

    path = Path(data_path)
    # Only gather images from train/ and val/ so that other (e.g. transformed)
    # files stored under `path` do not leak into the dataset.
    fnames = get_image_files(path/'train') + get_image_files(path/'val')

    # Labels come from each image's parent folder. NOTE: from_path_func does not
    # respect the train/ vs val/ folders -- it pools `fnames` and holds out a random
    # validation subset, so pass `seed` to make that split (and metrics) reproducible.
    dls = ImageDataLoaders.from_path_func(
        path, fnames, label_func=parent_label, seed=seed, item_tfms=Resize(224))

    # Learner built on pretrained weights for the chosen architecture
    # (downloaded on first use).
    learn = vision_learner(dls, arch, metrics=accuracy)

    # fine_tune: one epoch training only the new head (body frozen), then
    # `epochs` epochs with the whole network unfrozen.
    learn.fine_tune(epochs)

    # Pickle the whole nn.Module (not just a state_dict); reloading it needs
    # torch.load(..., weights_only=False) on PyTorch >= 2.6. Create the target
    # directory first, otherwise torch.save fails on a fresh checkout.
    save_dir = Path(models_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    torch.save(learn.model, save_dir/f'{model_name}.pt')

In [16]:
# Fine-tune resnet18 (1 frozen + 1 unfrozen epoch) and save it under models/.
train(model_name='resnet18', epochs=1)

epoch,train_loss,valid_loss,accuracy,time
0,0.216833,0.05047,0.985063,01:53


epoch,train_loss,valid_loss,accuracy,time
0,0.102142,0.047981,0.986931,02:14


In [17]:
# Fine-tune resnet34 (1 frozen + 1 unfrozen epoch) and save it under models/.
train(model_name='resnet34', epochs=1)

epoch,train_loss,valid_loss,accuracy,time
0,0.174275,0.025039,0.991785,02:21


epoch,train_loss,valid_loss,accuracy,time
0,0.104469,0.045552,0.985437,03:01


In [18]:
# Fine-tune resnet50 (1 frozen + 1 unfrozen epoch) and save it under models/.
train(model_name='resnet50', epochs=1)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /Users/juhokokko/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:02<00:00, 35.9MB/s]


epoch,train_loss,valid_loss,accuracy,time
0,0.16385,0.029039,0.992158,03:46


epoch,train_loss,valid_loss,accuracy,time
0,0.057236,0.034291,0.991038,04:30


In [4]:
#|export
def load(model_name, models_dir='models', target_device=None):
    """Load a whole pickled model (as written by ``train``) and move it to a device.

    Args:
        model_name: checkpoint stem; the file read is ``<models_dir>/<model_name>.pt``.
        models_dir: directory holding the checkpoints (defaults to ``'models'``).
        target_device: device to load onto; ``None`` uses the module-level ``device``.

    Returns:
        The unpickled ``torch.nn.Module`` on ``target_device``. Call ``.eval()``
        on it before running inference.
    """
    # Only evaluate the global `device` when no explicit target was given.
    target_device = device if target_device is None else target_device
    # map_location lets a checkpoint written from a CUDA session load on a CPU-only
    # machine. weights_only=False is required because the file holds a whole pickled
    # nn.Module rather than a state_dict; PyTorch >= 2.6 defaults to weights_only=True
    # and would refuse it. Unpickling can execute code -- only load trusted files.
    model = torch.load(f'{models_dir}/{model_name}.pt',
                       map_location=target_device, weights_only=False)
    model.to(target_device)
    return model

In [7]:
# One random RGB image at the 224x224 size used in training (Resize(224) in train()).
# Created on `device`, the same device load() moves the models to; a CPU tensor fed
# to a CUDA model would raise a device-mismatch RuntimeError.
input_data = torch.randn(1, 3, 224, 224, device=device)

In [8]:
# Smoke test: reload the saved resnet18 and check that one 3x224x224 input
# yields a (1, 10) output, i.e. 10 class scores.
model = load('resnet18').eval()
with torch.no_grad():
    output = model(input_data)
assert output.shape == (1, 10), f"Expected output shape to be (1, 10), but got {output.shape}"

In [37]:
# Smoke test: reload the saved resnet34 and check that one 3x224x224 input
# yields a (1, 10) output, i.e. 10 class scores.
model = load('resnet34').eval()
with torch.no_grad():
    output = model(input_data)
assert output.shape == (1, 10), f"Expected output shape to be (1, 10), but got {output.shape}"

In [40]:
# Smoke test: reload the saved resnet50 and check that one 3x224x224 input
# yields a (1, 10) output, i.e. 10 class scores.
model = load('resnet50').eval()
with torch.no_grad():
    output = model(input_data)
assert output.shape == (1, 10), f"Expected output shape to be (1, 10), but got {output.shape}"

In [None]:
# NOTE(review): commented-out alternative to the torch.save(...) call inside train();
# it is never executed. `learn` only exists as a local inside train(), so these lines
# would raise NameError if simply uncommented here. Delete this cell, or move the
# logic into train() if fastai's own Learner.save checkpoint format is preferred.
#current_path = Path('.')
#learn.path = current_path
#learn.model_dir = 'models'
#learn.save('resnet18')