In [1]:
import os

import cv2
import timm
import torch
from torch.nn import MSELoss
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def determine_device():
    if torch.cuda.is_available():
        return 'cuda'
    elif torch.backends.mps.is_available():
        return 'mps'
    else:
        return 'cpu'


device = determine_device()
print(f'Device is {device}')

Device is mps


In [3]:
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=1)
model = model.to(device)
from timm.data import resolve_data_config, create_transform

config = resolve_data_config({}, model=model)
print(config)
transform = create_transform(**config)
print(transform)

{'input_size': (3, 224, 224), 'interpolation': 'bicubic', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'crop_pct': 0.9, 'crop_mode': 'center'}
Compose(
    Resize(size=248, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    MaybeToTensor()
    Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))
)


In [7]:
import urllib
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

config = resolve_data_config({}, model=model)
transform = create_transform(**config)

url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)
img = Image.open(filename).convert('RGB')
tensor = transform(img).unsqueeze(0)
print(tensor.shape)

tensor = tensor.to(device)
model(tensor)

torch.Size([1, 3, 224, 224])


tensor([[0.2534]], device='mps:0', grad_fn=<LinearBackward0>)

In [8]:
class CustomImageDataset(Dataset):
    def __init__(self, transform=None):
        outputs_dir = os.path.join(os.getcwd(), "outputs")
        dataset_file = os.path.join(outputs_dir, 'dataset', 'dataset.pt')
        x, y, sigma = torch.load(dataset_file)
        self.x = x
        self.y = y.unsqueeze(dim=-1)
        self.sigma = sigma
        self.transform = transform

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # TODO: data transformers
        image = self.x[idx]
        return image if self.transform is None else self.transform(image), self.y[idx]


dataset = CustomImageDataset(transform=transform)
train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])

In [9]:
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=64, shuffle=True)

In [10]:
index = 6
train_features, train_labels = next(iter(train_dataloader))
outputs_dir = os.path.join(os.getcwd(), "outputs")
cv2.imwrite(os.path.join(outputs_dir, f"output_{index}.jpg"), train_features[index].numpy().transpose((1, 2, 0)))

[ WARN:0@56.054] global loadsave.cpp:1063 imwrite_ Unsupported depth image for selected encoder is fallbacked to CV_8U.


True

In [11]:
train_features, train_labels = next(iter(train_dataloader))
train_features = train_features.to(device)
train_labels = train_labels.to(device)
model(train_features)


tensor([[-2.9417],
        [-2.9615],
        [-3.4184],
        [-2.7525],
        [-3.4995],
        [-3.3067],
        [-3.2270],
        [-2.8109],
        [-3.4074],
        [-3.0861],
        [-3.2146],
        [-3.0557],
        [-2.9839],
        [-2.7279],
        [-3.1739],
        [-2.4951],
        [-2.4629],
        [-3.0998],
        [-2.9713],
        [-2.5996],
        [-2.5224],
        [-2.8620],
        [-3.0789],
        [-3.2175],
        [-3.0509],
        [-3.2535],
        [-2.8265],
        [-2.5304],
        [-2.5634],
        [-3.0118],
        [-3.1917],
        [-3.3138],
        [-3.2846],
        [-2.8126],
        [-3.1080],
        [-3.2125],
        [-2.2847],
        [-3.0296],
        [-3.2456],
        [-2.9109],
        [-2.4987],
        [-2.8149],
        [-3.1199],
        [-3.0501],
        [-2.7353],
        [-3.4728],
        [-2.7140],
        [-2.5155],
        [-2.7812],
        [-3.2834],
        [-3.1311],
        [-2.9152],
        [-2.

In [12]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
criterion = MSELoss()

In [13]:
model.train()
for x, y in train_dataloader:
    x, y = x.to(device), torch.reshape(y, shape=(64, 1)).to(device)
    outputs = model(x)
    loss = criterion(outputs, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f'Loss is {loss}')

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.