In [1]:
%load_ext autoreload
%autoreload 2

import os,sys,inspect
# Make the project root (the parent of the notebook's working directory) importable,
# so `coord2vec` can be imported from a notebook living one level below it.
# The kernel's CWD is used instead of inspect.getfile(inspect.currentframe()): under
# recent ipykernel releases that call resolves to a temporary cell file
# (e.g. /tmp/ipykernel_<pid>/<n>.py), whose parent is not the project root.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
# Keep the entry at the front but avoid stacking duplicates when the cell is re-run.
if parent_dir in sys.path:
    sys.path.remove(parent_dir)
sys.path.insert(0, parent_dir)
import numpy as np
import pandas as pd

from coord2vec.models.data_loading.tile_features_loader import TileFeaturesDataset
from coord2vec.models.model_utils import get_data_loader, get_pytorch_dataset
from coord2vec.config import CACHE_DIR

In [2]:
# Load the data: the project's default DataLoader from coord2vec.models.model_utils.
data_loader = get_data_loader()
# Later cells index dataset[0] and the training loop would silently do nothing on an
# empty dataset, so fail early with a clear message instead.
assert len(data_loader.dataset) > 0, "get_data_loader() returned an empty dataset"

In [19]:
# Peek at the target-feature vector of the first sample (element 1 of a dataset item;
# element 0 is the image).
# NOTE(review): the recorded output contains NaN entries, and its first value (~5e5)
# is orders of magnitude larger than the rest.
first_sample = data_loader.dataset[0]
first_sample[1]

tensor([507385.0625,         nan,      0.0000,         nan,      0.0000])

In [34]:
# Create the NN
from coord2vec.models.architectures import resnet18, multihead_model, dual_fc_head
from coord2vec.models.losses import multihead_loss

import torch.nn as nn
import torch.optim as optim
import torch

# Model: a resnet18 encoder mapping an image to a z_dim embedding, shared by two heads
# built with dual_fc_head; the target vector is split column-wise across the heads.
z_dim = 128
mtl_head_sizes = (3, 2)  # number of feature columns predicted by head 1 and head 2
n_channels = data_loader.dataset[0][0].shape[0]  # channels of one (channels-first) sample image

# torch.split in the training loop needs the head sizes to cover every feature column exactly.
n_features = data_loader.dataset[0][1].shape[0]
assert sum(mtl_head_sizes) == n_features, (
    f"mtl_head_sizes {mtl_head_sizes} must sum to the {n_features} feature columns"
)

torch.manual_seed(0)  # reproducible weight initialisation across Restart & Run All


class MaskedL1Loss(nn.Module):
    """Mean absolute error that ignores positions where the target is NaN.

    Drop-in replacement for ``nn.L1Loss()`` (mean reduction). The sample targets
    shown earlier contain NaNs; with plain ``nn.L1Loss`` a single NaN target makes
    the whole loss -- and therefore every gradient and, after one optimizer step,
    the weights -- NaN.

    Assumes ``prediction`` and ``target`` have the same shape (as nn.L1Loss expects).
    If every target is NaN, returns a zero that stays attached to the autograd graph
    so ``backward()`` still works.
    """

    def forward(self, prediction, target):
        valid = ~torch.isnan(target)
        if not valid.any():
            return prediction.sum() * 0.0
        return (prediction[valid] - target[valid]).abs().mean()


encoder = resnet18(n_channels, z_dim)
head1 = dual_fc_head(z_dim, n_classes=mtl_head_sizes[0])
head2 = dual_fc_head(z_dim, n_classes=mtl_head_sizes[1])
model = multihead_model(encoder, [head1, head2])

# create the losses: one NaN-aware L1 per head.
# NOTE(review): the first target column was ~5e5 in the sample shown earlier while the
# others were ~0; consider normalising targets so that column does not dominate the loss.
criterion = multihead_loss([MaskedL1Loss(), MaskedL1Loss()])
optimizer = optim.Adam(model.parameters())

In [35]:
# Train the model: each batch's feature vector is split column-wise, one slice per head.
max_epochs = 2

# Put BatchNorm/Dropout layers (if any) in training mode explicitly; an inference cell
# run earlier in the session may have left the model in eval mode.
model.train()
epoch_losses = []  # mean training loss per epoch, displayed at the end of the cell

for epoch in range(max_epochs):
    running_loss, n_batches = 0.0, 0
    for images_batch, features_batch in data_loader:
        # split the features into the multi_heads:
        split_features_batch = torch.split(features_batch, mtl_head_sizes, dim=1)

        optimizer.zero_grad()

        # Call the module rather than .forward() so registered hooks run; element 1 of
        # the output is what the multi-head criterion consumes (element 0 is the embedding).
        output = model(images_batch)[1]
        loss = criterion(output, split_features_batch)
        # A NaN/inf loss would poison every weight on optimizer.step(); stop loudly instead.
        if not torch.isfinite(loss):
            raise FloatingPointError(
                f"non-finite loss ({loss.item()}) at epoch {epoch}; "
                "check the targets for NaN/inf values"
            )
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
    epoch_losses.append(running_loss / max(n_batches, 1))

epoch_losses

In [38]:
# Smoke-test the forward pass on a dummy batch of 3 all-ones 224x224 images.
# eval() + no_grad(): in train mode this artificial batch would update any BatchNorm
# running statistics (standard ResNet-18 has BatchNorm), corrupting the trained model,
# and autograd would record a graph nobody uses. The previous mode is restored afterwards.
was_training = model.training
model.eval()
with torch.no_grad():
    x_pred = torch.ones(3, n_channels, 224, 224)  # channel count taken from the data, not hard-coded
    y_pred = model(x_pred)
model.train(was_training)
y_pred[0].shape

torch.Size([3, 128])

In [28]:
# Embed one real sample; unsqueeze(0) adds the batch dimension.
# Inference runs in eval mode under no_grad so any BatchNorm running statistics are not
# updated by this single sample; the previous train/eval mode is restored afterwards.
# NOTE(review): the recorded output below was produced at In [28], i.e. before the
# model cell (In [34]) ran, so it does not reflect the current model.
was_training = model.training
model.eval()
with torch.no_grad():
    x_pred = data_loader.dataset[0][0].unsqueeze(0)
    y_pred = model(x_pred)
model.train(was_training)
y_pred[0]

tensor([-37.5586,  -5.7906,   3.3974,  17.2079,  -4.3377,  -5.1141, -21.1913,
         42.6451,  24.8207,  14.1980, -38.6404, -26.9991, -22.4609,  -1.6728,
          9.2489, -12.6510, -31.5708, -26.6453, -15.1955,  37.4996,  -2.7640,
        -30.1221,  18.4570, -42.1786,  -1.1221, -14.3957,  39.2992, -31.2374,
        -35.8495,  32.0851, -10.8533, -31.9567,  -8.2693, -30.8625,  39.6849,
         13.4793,  -3.6356,  -0.7105, -41.6769,  -0.7229,   6.6207,  22.6575,
         -5.6362,  35.7838, -41.1668, -34.6632,  23.2290, -14.5626, -38.9101,
        -18.0895,  22.6260,  21.0566,  39.2145,   3.9535,  35.2733,   5.6688,
          4.0585, -42.6405, -42.3857,  26.8344,  19.4617, -33.0107, -20.7948,
         13.8923, -42.3280,  26.5748, -24.3154,   3.1902,  12.8675,  -5.0913,
          3.5130, -32.5561,  -3.4316,   2.7310, -35.5863,  37.4442, -18.2592,
         34.8666, -43.3461, -37.1842, -19.2268,  20.1802, -33.1238,   0.4628,
         19.0287,  -2.6069,  -0.6670,  35.6660,  17.5473, -24.00