See the torchvision MobileNetV2 implementation and its PyTorch Hub page:
https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py
https://pytorch.org/hub/pytorch_vision_mobilenet_v2/

In [None]:
import torch
import torchvision
import numpy as np
from glob import glob
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from os import path
from PIL import Image

# Download the pretrained MobileNetV2 from the PyTorch Hub and keep every
# top-level child except the last one (the classification head in
# torchvision's MobileNetV2), so the remaining network outputs feature maps
# that are used as image embeddings.
mobile_net = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
backbone_layers = list(mobile_net.children())[:-1]
embedding_model = torch.nn.Sequential(*backbone_layers)
# Inference mode: fixes dropout / batch-norm behaviour of the backbone.
embedding_model.eval()

# Small trainable head mapping a 1280-dimensional embedding to 3 class scores.
#
# The head deliberately ends in raw logits: torch.nn.CrossEntropyLoss applies
# log-softmax internally, so a trailing Softmax layer would be applied twice
# and flatten the gradients. argmax over the logits yields the same predicted
# class; use torch.softmax(output, dim=1) at inference time if probabilities
# are needed. Softmax has no parameters, so the state_dict keys are unchanged.
classifier = torch.nn.Sequential(
    torch.nn.Linear(1280, out_features=10, bias=True),
    torch.nn.ReLU(),
    torch.nn.Linear(10, out_features=3, bias=True),
)


In [None]:
class SteerDataSet(Dataset):
    """Image dataset whose steering labels are encoded in the file names.

    Every file directly inside ``root_folder`` whose name ends in ``img_ext``
    is one sample. The steering value is parsed from the file name: the
    extension is stripped and the first ``_PREFIX_LEN`` characters of the
    remainder are skipped (presumably a fixed-width frame index -- confirm
    against the code that recorded the data).

    Args:
        root_folder: Directory that is globbed (non-recursively) for images.
        img_ext: File-name suffix of the images, including the dot.
        transform: Optional callable applied to the opened PIL image. When
            ``None``, the image is only converted with
            ``transforms.ToTensor()``.
    """

    # Number of leading file-name characters that precede the steering value.
    _PREFIX_LEN = 6

    def __init__(self, root_folder, img_ext=".jpg", transform=None):
        self.root_folder = root_folder
        self.transform = transform
        self.img_ext = img_ext
        # glob() returns matches in arbitrary order; sorting keeps the
        # index -> file mapping reproducible across runs and machines.
        self.filenames = sorted(glob(path.join(self.root_folder, "*" + self.img_ext)))
        self.totensor = transforms.ToTensor()

    def __len__(self):
        """Return the number of image files found in ``root_folder``."""
        return len(self.filenames)

    def _parse_steering(self, filename):
        """Extract the steering value embedded in ``filename`` as ``np.float32``."""
        # basename() honours the platform's path separator, unlike a
        # hard-coded split on "/".
        name = path.basename(filename)
        stem = name.split(self.img_ext)[0]
        return np.float32(stem[self._PREFIX_LEN:])

    @staticmethod
    def _steering_class(steering):
        """Map a steering value to a class index: >0 -> 0, <0 -> 1, else 2."""
        if steering > 0:
            return 0
        if steering < 0:
            return 1
        return 2

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A dict with keys ``"image"`` (the transformed image),
            ``"steering"`` (``np.float32`` parsed from the file name) and
            ``"simple_steering"`` (class index 0, 1 or 2 derived from the
            sign of the steering value).
        """
        f = self.filenames[idx]
        img = Image.open(f)

        if self.transform is None:
            img = self.totensor(img)
        else:
            img = self.transform(img)

        steering = self._parse_steering(f)
        simple_steering = self._steering_class(steering)

        return {"image": img, "steering": steering, "simple_steering": simple_steering}

In [None]:
BATCH_SIZE = 16

# Image preprocessing: resize, centre-crop to 224x224, convert to a tensor and
# normalise each channel (the mean/std values are presumably the ImageNet
# statistics the pretrained backbone was trained with -- confirm).
preprocess = transforms.Compose(
    [
        transforms.Resize(256),  # TODO downsizing
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Training data, reshuffled every epoch.
dataset = SteerDataSet(root_folder="data1", img_ext=".jpg", transform=preprocess)
trainloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

# Only the classifier head's parameters are handed to the optimizer.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=classifier.parameters(), lr=0.001, momentum=0.9)


In [None]:
# Train the classifier head on embeddings produced by the embedding model.
# The embedding model is only run under torch.no_grad(), so the optimizer
# step updates the classifier's parameters alone.
PRINT_FREQ = 20   # report statistics every PRINT_FREQ batches
NUM_EPOCHS = 4    # number of passes over the training data

classifier.train()

for epoch in range(NUM_EPOCHS):
    # Statistics accumulated since the last report.
    running_loss = 0.0
    running_acc = 0.0
    window_batches = 0

    for i, s in enumerate(trainloader):
        # bring data in right format
        data = s['image']
        label = s['simple_steering']

        # apply embedding model (no gradients needed for the backbone)
        with torch.no_grad():
            feature_maps = embedding_model(data)
            pooled = torch.nn.functional.adaptive_avg_pool2d(feature_maps, (1, 1))
            embeddings = torch.flatten(pooled, 1)

        # backprop classifier
        # NOTE(review): criterion is CrossEntropyLoss, which applies
        # log-softmax itself -- `output` should be raw logits, so make sure
        # the classifier does not end in its own Softmax layer.
        optimizer.zero_grad()
        output = classifier(embeddings)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        # update train-accuracy; averaging over the actual batch keeps the
        # value correct when the final batch is smaller than BATCH_SIZE
        pred = torch.argmax(output, dim=1)
        this_acc = (pred == label).float().mean().item()

        running_loss += loss.item()
        running_acc += this_acc
        window_batches += 1

        if i % PRINT_FREQ == 0:
            # Divide by the number of batches actually accumulated: the first
            # report of an epoch covers one batch, later ones PRINT_FREQ.
            mean_loss = running_loss / window_batches
            mean_acc = running_acc / window_batches
            print(f"epoch: {epoch},\t item:{i},\t loss:{mean_loss:.3f},\t acc:{mean_acc:.2f}")
            running_loss = 0.0
            running_acc = 0.0
            window_batches = 0

In [None]:
# Persist only the classifier head's parameters (its state_dict), not the
# module object itself.
classifier_weights = classifier.state_dict()
torch.save(classifier_weights, 'class_params.pt')