In [9]:
#|export
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai, push_to_hub_fastai
repo = 'ninjalabo/'  # Hugging Face Hub namespace; full repo ids are built as repo + model_name

In [17]:
# You must log in to Hugging Face before uploading models to ninjalabo
# from huggingface_hub import notebook_login
# notebook_login()

In [2]:
#|export
# Supported architectures. The key is the name accepted by `train`/`load` and is
# also appended to `repo` to form the Hugging Face Hub repo id.
models = dict(
    resnet18=resnet18,
    resnet34=resnet34,
    resnet50=resnet50,
)

def train(model_name, epochs=1, size=224, data_dir=None, seed=None):
    """Fine-tune an ImageNet-pretrained architecture from `models` on Imagenette-320.

    Args:
        model_name: key of `models` (e.g. 'resnet18').
        epochs: number of epochs passed to `Learner.fine_tune` (fastai's
            `fine_tune` additionally runs frozen epoch(s) before these).
        size: side length that every image is resized to (item transform).
        data_dir: directory the dataset is extracted into; defaults to
            `./data` under the current working directory.
        seed: if not None, seeds all RNGs via fastai's `set_seed` so that
            repeated runs are comparable. `None` keeps the unseeded behavior.

    Returns:
        The fine-tuned fastai `Learner`.

    Raises:
        ValueError: if `model_name` is not a key of `models`.
    """
    if model_name not in models:
        raise ValueError(f"Model name '{model_name}' is not supported. Choose from: {list(models.keys())}")
    if seed is not None:
        set_seed(seed, reproducible=True)

    arch = models[model_name]  # architecture constructor, not an instantiated model
    data_dir = Path.cwd()/'data' if data_dir is None else Path(data_dir)
    path = untar_data(URLs.IMAGENETTE_320, data=data_dir)
    # The validation split is the dataset's 'val' folder.
    dls = ImageDataLoaders.from_folder(
        path,
        valid='val',
        item_tfms=Resize(size),
        batch_tfms=Normalize.from_stats(*imagenet_stats),
    )
    learn = vision_learner(dls, arch, metrics=accuracy, pretrained=True)
    learn.fine_tune(epochs)
    return learn

def train_all(epochs=1, model_names=None):
    """Fine-tune several architectures and push each learner to the Hugging Face Hub.

    Each learner is uploaded to `repo + model_name` (e.g. 'ninjalabo/resnet18'),
    so you must be logged in to Hugging Face with write access to that namespace.

    Args:
        epochs: number of epochs forwarded to `train` for every model.
        model_names: a single key of `models`, or an iterable of keys, to train.
            `None` (the default) trains every entry of `models` in dict order.

    Raises:
        ValueError: if any requested name is not a key of `models`. All names
            are checked before any training starts, so a typo cannot abort the
            loop after earlier models have already been trained and pushed.
    """
    if model_names is None:
        names = list(models)
    elif isinstance(model_names, str):
        # A bare string would otherwise be iterated character by character.
        names = [model_names]
    else:
        names = list(model_names)

    unknown = [name for name in names if name not in models]
    if unknown:
        raise ValueError(f"Model name(s) {unknown} not supported. Choose from: {list(models.keys())}")

    for model_name in names:
        learn = train(model_name, epochs)
        push_to_hub_fastai(learner=learn, repo_id=repo + model_name)

In [97]:
# Fine-tune every architecture in `models` and push each learner to the Hub (requires login).
train_all(epochs=1)

epoch,train_loss,valid_loss,accuracy,time
0,0.239424,0.060668,0.980637,01:23


epoch,train_loss,valid_loss,accuracy,time
0,0.09718,0.057779,0.98293,01:46


model.pkl: 100%|██████████| 47.1M/47.1M [00:04<00:00, 11.5MB/s]


epoch,train_loss,valid_loss,accuracy,time
0,0.195063,0.041879,0.987006,01:52


epoch,train_loss,valid_loss,accuracy,time
0,0.081703,0.047114,0.987771,02:29


model.pkl: 100%|██████████| 87.6M/87.6M [00:06<00:00, 13.5MB/s]


epoch,train_loss,valid_loss,accuracy,time
0,0.182609,0.051196,0.986242,03:22


epoch,train_loss,valid_loss,accuracy,time
0,0.058112,0.031299,0.991847,04:14


model.pkl: 100%|██████████| 103M/103M [00:07<00:00, 14.3MB/s] 


In [18]:
# Longer (3-epoch) resnet50 run. Keep the learner so it can be evaluated or pushed afterwards
# instead of being discarded as an unused cell result.
learn_resnet50 = train('resnet50', epochs=3)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /Users/juhokokko/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:02<00:00, 35.9MB/s]


epoch,train_loss,valid_loss,accuracy,time
0,0.16385,0.029039,0.992158,03:46


epoch,train_loss,valid_loss,accuracy,time
0,0.057236,0.034291,0.991038,04:30


In [3]:
#|export
def load(model_name, revision=None):
    """Download a fine-tuned fastai learner from the Hugging Face Hub.

    Args:
        model_name: repo name under `repo`, e.g. 'resnet18' -> 'ninjalabo/resnet18'.
        revision: optional branch, tag or commit hash to pin. Pinning makes
            downstream results reproducible even if the repo is later re-pushed.
            `None` (the default) keeps the previous behavior.

    Returns:
        The `Learner` stored in the repo.
    """
    # NOTE(review): the learner is unpickled from `model.pkl`, which can execute
    # arbitrary code. Only load from repos you trust.
    return from_pretrained_fastai(repo_id=repo + model_name, revision=revision)

In [4]:
import torch

# Dummy batch: one 3-channel 224x224 image, matching the Resize(224) used in `train`.
# A private, fixed-seed generator makes the printed logits reproducible without
# touching torch's global RNG.
input_data = torch.randn(1, 3, 224, 224, generator=torch.Generator().manual_seed(42))

In [10]:
# Smoke test: the resnet18 learner on the Hub loads and maps one 224x224 image
# to 10 logits (one per Imagenette class it was fine-tuned on).
model = load('resnet18').model
model.eval()  # inference mode (e.g. dropout off, BatchNorm uses running statistics)
with torch.no_grad():
    output = model(input_data)
    print(output)

assert output.shape == (1, 10), f"Expected output shape to be (1, 10), but got {output.shape}"
# A shape check alone still passes for corrupted weights; NaN/inf logits do not.
assert torch.isfinite(output).all(), f"Expected finite logits, but got {output}"

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([[ 0.2305, -3.9527, -0.4192, -0.8148,  3.0921, -1.3428, -0.8843, -4.9920,
          3.0464,  4.2255]])


In [11]:
# Smoke test: the resnet34 learner on the Hub loads and maps one 224x224 image
# to 10 logits (one per Imagenette class it was fine-tuned on).
model = load('resnet34').model
model.eval()  # inference mode (e.g. dropout off, BatchNorm uses running statistics)
with torch.no_grad():
    output = model(input_data)
    print(output)

assert output.shape == (1, 10), f"Expected output shape to be (1, 10), but got {output.shape}"
# A shape check alone still passes for corrupted weights; NaN/inf logits do not.
assert torch.isfinite(output).all(), f"Expected finite logits, but got {output}"

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/768 [00:00<?, ?B/s]

pyproject.toml:   0%|          | 0.00/162 [00:00<?, ?B/s]

model.pkl:   0%|          | 0.00/87.6M [00:00<?, ?B/s]

tensor([[-1.1578, -1.6175,  0.6183, -1.0238,  3.4288, -0.2779, -1.9741, -4.9624,
          4.7883,  2.1142]])


In [12]:
# Smoke test: the resnet50 learner on the Hub loads and maps one 224x224 image
# to 10 logits (one per Imagenette class it was fine-tuned on).
model = load('resnet50').model
model.eval()  # inference mode (e.g. dropout off, BatchNorm uses running statistics)
with torch.no_grad():
    output = model(input_data)
    print(output)

assert output.shape == (1, 10), f"Expected output shape to be (1, 10), but got {output.shape}"
# A shape check alone still passes for corrupted weights; NaN/inf logits do not.
assert torch.isfinite(output).all(), f"Expected finite logits, but got {output}"

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

pyproject.toml:   0%|          | 0.00/162 [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/768 [00:00<?, ?B/s]

model.pkl:   0%|          | 0.00/103M [00:00<?, ?B/s]

tensor([[-1.6933, -2.6376, -0.0326, -0.3681,  3.9738, -1.3005, -2.0533, -3.8796,
          6.4267,  5.8173]])
