In [1]:
import os

import cv2
import timm
import torch
from torch.nn import MSELoss
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CustomImageDataset(Dataset):
    def __init__(self):
        outputs_dir = os.path.join(os.getcwd(), "outputs")
        dataset_file = os.path.join(outputs_dir, 'dataset', 'dataset.pt')
        x, y, sigma = torch.load(dataset_file)
        self.x = x
        self.y = y
        self.sigma = sigma

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # TODO: data transformers
        return self.x[idx], self.y[idx]


dataset = CustomImageDataset()
train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])

In [3]:
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=64, shuffle=True)

In [4]:
index = 1
train_features, train_labels = next(iter(train_dataloader))
outputs_dir = os.path.join(os.getcwd(), "outputs")
cv2.imwrite(os.path.join(outputs_dir, f"output_{index}.jpg"), train_features[index].numpy().transpose((1, 2, 0)))

[ WARN:0@3.935] global loadsave.cpp:1063 imwrite_ Unsupported depth image for selected encoder is fallbacked to CV_8U.


True

In [5]:
def determine_device():
    if torch.cuda.is_available():
        return 'cuda'
    elif torch.backends.mps.is_available():
        return 'mps'
    else:
        return 'cpu'


device = 'cpu'
print(f'Device is {device}')

Device is cpu


In [6]:
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.head = torch.nn.Linear(model.head.in_features, 1)
model.to(device)
print(model)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False

In [7]:
train_features, train_labels = next(iter(train_dataloader))
model(train_features)


tensor([[-0.8902],
        [-0.9146],
        [-0.7982],
        [-1.0480],
        [-1.0954],
        [-0.7769],
        [-1.7268],
        [-1.0766],
        [-0.8595],
        [-1.0234],
        [-0.7655],
        [-0.7718],
        [-1.2850],
        [-1.0694],
        [-1.1243],
        [-1.1406],
        [-0.4718],
        [-1.0263],
        [-1.5033],
        [-0.6786],
        [-1.0504],
        [-1.5661],
        [-0.9432],
        [-1.4443],
        [-0.9487],
        [-0.9796],
        [-1.2428],
        [-1.4924],
        [-0.4636],
        [-1.2577],
        [-0.3153],
        [-0.9270],
        [-1.1845],
        [-1.0517],
        [-1.2155],
        [-0.9963],
        [-1.3562],
        [-1.1281],
        [-0.9804],
        [-1.2161],
        [-1.2281],
        [-1.2612],
        [-0.3056],
        [-1.6277],
        [-1.4559],
        [-1.1971],
        [-0.9529],
        [-1.1307],
        [-1.0959],
        [-1.6332],
        [-0.8067],
        [-0.9862],
        [-1.

In [7]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
criterion = MSELoss()

In [9]:
model.train()
for x, y in train_dataloader:
    x, y = x.to(device), y.unsqueeze(dim=-1).to(device)
    outputs = model(x)
    loss = criterion(outputs, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f'Loss is {loss}')

Loss is 2.3991472721099854
Loss is 1.6404043436050415
Loss is 3.6473677158355713
Loss is 4.84865140914917
Loss is 2.3598921298980713
Loss is 1.1724205017089844
Loss is 0.5166054368019104
Loss is 0.8146299123764038
Loss is 1.578791856765747
Loss is 1.565467119216919
Loss is 1.0256702899932861


KeyboardInterrupt: 