In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms.functional as F
from PIL import Image
from timm.models.layers.std_conv import StdConv2dSame
from timm.models.resnetv2 import ResNetV2
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.io import read_image
from torch.utils.data import DataLoader
from models import PetPopularityRegression


In [2]:
class PetDataset(Dataset):
    """Map-style dataset yielding ``(image, attributes, label)`` per CSV row.

    Rows of the annotations CSV are read positionally: the first column is the
    image id (the file is ``<img_dir>/<id>.jpg``), the last column is the
    label, and every column in between is treated as a numeric attribute.
    """

    def __init__(
        self, annotations_file, img_dir, transform=None, target_transform=None
    ):
        """Load the annotations table and remember where the images live.

        Args:
            annotations_file: Path (or buffer) passed to ``pd.read_csv``.
            img_dir: Directory containing the ``<id>.jpg`` files.
            transform: Optional callable applied to the loaded RGB image.
            target_transform: Optional callable applied to the label.
        """
        self.pet_data = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.pet_data)

    def __getitem__(self, idx):
        """Return the sample at positional index ``idx``.

        Returns:
            A tuple ``(image, attributes, label)`` where ``image`` is the RGB
            image (after ``transform`` if given), ``attributes`` is a 1-D
            ``float32`` tensor built from the middle columns, and ``label`` is
            the last column's value (after ``target_transform`` if given).
        """
        row = self.pet_data.iloc[idx]

        img_path = os.path.join(self.img_dir, row.iloc[0] + ".jpg")
        # Image.open is lazy and keeps the file handle open; convert() forces
        # the pixel data to load, so the handle can be released right away.
        # Forcing RGB keeps grayscale/RGBA/CMYK files compatible with
        # 3-channel transforms and with stacking samples into one batch.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        attributes = torch.tensor(
            [float(value) for value in row.iloc[1:-1]], dtype=torch.float32
        )
        label = row.iloc[-1]

        if self.transform is not None:
            image = self.transform(image)
        # Previously accepted by __init__ but silently ignored.
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, attributes, label


class SquarePad:
    """Transform that pads a PIL image with zeros until it is exactly square.

    The shorter dimension is padded on both sides; when the difference between
    width and height is odd, the extra pixel goes to the right/bottom edge so
    the result is always ``max(w, h) x max(w, h)``.
    """

    def __call__(self, image):
        """Return ``image`` zero-padded to a square.

        Args:
            image: An object exposing ``.size`` as ``(width, height)``
                (a PIL image in this pipeline).
        """
        width, height = image.size
        side = max(width, height)

        # Halving the deficit and using it for both edges loses one pixel when
        # the deficit is odd; give the remainder to the far edge instead.
        left = (side - width) // 2
        top = (side - height) // 2
        right = side - width - left
        bottom = side - height - top

        # Padding order is (left, top, right, bottom).
        return F.pad(
            image, (left, top, right, bottom), fill=0, padding_mode="constant"
        )


# Image preprocessing applied to every sample, in order.
transform = transforms.Compose([
    # Zero-pad to a (near-)square canvas so the resize keeps the aspect ratio.
    SquarePad(),
    # Scale so the shorter edge is 256 pixels.
    transforms.Resize(size=256),
    # Trim to exactly 256x256.
    transforms.CenterCrop(size=256),
    # PIL image -> float tensor in [0, 1], channels first.
    transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5, mapping [0, 1] onto [-1, 1].
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


In [3]:
def collate_fn(data):
    """Merge a list of ``(image, attributes, target)`` samples into one batch.

    Args:
        data: Sequence of 3-tuples whose first two items are tensors of
            matching shape across samples and whose third item is a number.

    Returns:
        ``((images, attributes), targets)`` where ``images`` and ``attributes``
        are the per-sample tensors stacked along a new leading dimension and
        ``targets`` is a 1-D float tensor of the labels.
    """
    # Transpose the list of samples into three parallel tuples.
    image_seq, attribute_seq, target_seq = zip(*data)

    inputs = (torch.stack(image_seq), torch.stack(attribute_seq))
    labels = torch.Tensor(target_seq)
    return inputs, labels


In [4]:
# Training set: annotations CSV plus the matching image folder, with the
# preprocessing pipeline applied to every image.
pet_dataset = PetDataset(
    annotations_file="./raw_data/train.csv",
    img_dir="./raw_data/train",
    transform=transform,
)
# Small batches of 4, assembled by the custom collate function.
loader = DataLoader(dataset=pet_dataset, batch_size=4, collate_fn=collate_fn)


In [5]:
# Pull the first batch from the loader and unpack it into the stacked images,
# the stacked attribute vectors, and the target tensor.
(image_batch, attribute_batch), targets = next(iter(loader))


In [6]:
# Instantiate the regression model from the project-local `models` module
# with its default constructor arguments.
# NOTE(review): the captured output of this cell shows a line calling
# `nn.init.xavier_uniform`, presumably the tail of a deprecation warning raised
# inside the model's constructor; confirm and consider `xavier_uniform_` there.
model = PetPopularityRegression()


  nn.init.xavier_uniform(self.linear.weight)


In [7]:
# Smoke test: run one forward pass, passing the images and attributes as a
# single tuple argument; the bare expression is displayed as the cell output.
# NOTE(review): the recorded output carries a grad_fn, so autograd is tracking
# this call; wrap it in torch.no_grad() if only predictions are being inspected.
model((image_batch, attribute_batch))

tensor([[54.5662],
        [57.3383],
        [66.4959],
        [42.3373]], grad_fn=<MulBackward0>)