In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import timm
import os

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.001
DATA_DIR = "data/"

In [4]:
class CustomDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.labels = {}

        # Read the labels.txt and create a mapping from image name to a list of labels
        with open(os.path.join(data_dir, "labels.txt"), "r") as f:
            for line in f:
                image_name = line.strip().split(",")[0]
                self.labels[image_name] = [list(map(float, line.strip().split(",")[1:]))]
        #         image_name, relative_x, relative_y, action = line.strip().split(",")
        #         if image_name not in self.labels:
        #             self.labels[image_name] = []
        #         self.labels[image_name].append([float(relative_x), float(relative_y), float(action)])
        
        self.images = list(self.labels.keys())

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        image_path = os.path.join(self.data_dir, img_name)
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        label_matrix = torch.tensor(self.labels[img_name])  # This will be of size 10x3
        return image, label_matrix

In [5]:
# Transform
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # Appropriate for EfficientNet-B0
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [6]:
dataset = CustomDataset(data_dir=DATA_DIR, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [7]:
# dataset[0] will return a tuple of image and label matrix
dataset[0][0] # image
dataset[0][1].shape # label

torch.Size([1, 40])

In [8]:
class EfficientNet(nn.Module):
    def __init__(self, out_size: int = 40, out_scale: float = 1, freeze_base: bool = True):
        super().__init__()
        self.out_scale = out_scale
        
        # Load the pre-trained EfficientNet model (for this example, we'll use efficientnet_b0)
        eff_net = timm.create_model('efficientnet_b0', pretrained=True)

        if freeze_base:
            for param in eff_net.parameters():
                param.requires_grad = False

        # replace the last fully connected layer
        num_features = eff_net.classifier.in_features
        eff_net.classifier = nn.Linear(num_features, out_size)
        self.model = nn.Sequential(eff_net, nn.Tanh())

    def forward(self, x):
        x = self.model(x) * self.out_scale  # allows predicting up to out_scale meters away
        # return x.view(-1, 10, 4)  # reshaping the output to [batch_size, 10, 3]
        return x

    def get_transforms(self):
        return transforms.Compose(
            [
                transforms.ConvertImageDtype(torch.float),
                transforms.Resize((224, 224), antialias=True),  # appropriate for efficientnet_b0
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),  # imagenet norms
            ]
        )

In [9]:
model = EfficientNet()  # x and y as outputs
criterion = torch.nn.MSELoss()  # regression problem, so use Mean Squared Error
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [10]:
for epoch in range(EPOCHS):
    for images, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item():.4f}")

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1/10, Loss: 18.1916
Epoch 2/10, Loss: 17.0890
Epoch 3/10, Loss: 16.1286
Epoch 4/10, Loss: 15.3496
Epoch 5/10, Loss: 14.7549
Epoch 6/10, Loss: 14.3199
Epoch 7/10, Loss: 14.0093
Epoch 8/10, Loss: 13.7896
Epoch 9/10, Loss: 13.6344
Epoch 10/10, Loss: 13.5245


visualize

In [11]:
out = model(dataset[0][0].unsqueeze(0))
# out.shape
out

tensor([[ 0.9585,  0.9554,  0.1230,  0.1981,  0.9451,  0.9549,  0.2526, -0.0923,
          0.9668,  0.9543,  0.1483,  0.0787,  0.9513,  0.9588,  0.1256, -0.0179,
          0.9659,  0.9363,  0.1338,  0.3259,  0.9575,  0.9555,  0.1438,  0.1734,
          0.9655,  0.9608,  0.7731,  0.0458,  0.9669,  0.9583,  0.4288,  0.3780,
          0.9609,  0.9579,  0.1944,  0.1321,  0.9639,  0.9540,  0.1725, -0.0571]],
       grad_fn=<MulBackward0>)

In [51]:
# torch.save(model.state_dict(), "efficientnet_model.pth")
import torch.onnx
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "efficientnet.onnx")

verbose: False, log level: Level.ERROR

