In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import matplotlib.pyplot as plt
import timm
import torch
import torch.nn as nn
import xarray as xr

In [None]:
from crims2s.dataset import S2SDataset, TransformedDataset
from crims2s.training.model.bayes import Projection
from crims2s.util import ECMWF_FORECASTS, collate_with_xarray

In [None]:
# Compose the Hydra configuration from the packaged config module, selecting
# the `conv_fcn` experiment through an override.
with hydra.initialize_config_module(config_module='crims2s.training.conf'):
    cfg = hydra.compose(
        config_name='config',
        overrides=['experiment=conv_fcn'],
    )

In [None]:
# Instantiate the transform described by the experiment config, then pair it
# with the raw S2S dataset read from the configured `dataset_dir`.
# NOTE(review): presumably TransformedDataset applies `t` to each example on
# access -- confirm against crims2s.dataset.
t = hydra.utils.instantiate(cfg.experiment.transform)
d = TransformedDataset(S2SDataset(cfg.experiment.dataset.dataset_dir), t)

In [None]:
# Batches of four examples, assembled with the project's own collate function
# rather than PyTorch's default one.
loader = torch.utils.data.DataLoader(
    dataset=d,
    batch_size=4,
    collate_fn=collate_with_xarray,
)

In [None]:
# Grab the first batch from the loader; it is the fixed input fed to the
# model in the cells below.
b = next(iter(loader))

In [None]:
# Download the FCN-ResNet50 segmentation model (with its pretrained weights)
# from the torchvision hub repository pinned at v0.10.0.
pretrained_model = torch.hub.load(
    repo_or_dir='pytorch/vision:v0.10.0',
    model='fcn_resnet50',
    pretrained=True,
)

# Swap the last layer of the classification head for a freshly initialised
# convolution with 8 output channels (the wrapper defined below splits the
# output into two groups of 4 channels).
pretrained_model.classifier[4] = nn.Conv2d(
    in_channels=512,
    out_channels=8,
    kernel_size=3,
    stride=1,
)

In [None]:
class PretrainedModelWrapper(nn.Module):
    """Feed projected forecast features through a pretrained image backbone.

    The wrapper projects ``batch['features_features']`` with a ``Projection``
    layer, averages away the trailing axis, runs the result through the
    pretrained model and splits the 8 output channels into two groups of 4:
    the first group for ``t2m`` and the second for ``tp``.

    Args:
        pretrained_model: Module whose call returns a mapping containing an
            ``'out'`` tensor of shape ``(batch, 8, height, width)``.
        in_features: Number of input features handed to ``Projection``.
        verbose: When True, print intermediate tensor shapes on every forward
            pass. Off by default so that the module stays silent in training
            loops.
    """

    def __init__(self, pretrained_model, in_features, verbose=False):
        super().__init__()

        # NOTE(review): the literal 3 is presumably the number of channels the
        # RGB-pretrained backbone expects -- confirm against Projection.
        self.projection = Projection(in_features, 3, moments=False, flatten_time=True)
        self.pretrained = pretrained_model
        self.verbose = verbose

    def _trace(self, label, tensor):
        """Print ``label`` and the shape of ``tensor`` when verbose mode is on."""
        if self.verbose:
            print(label, tensor.shape)

    @staticmethod
    def _pairs_to_last_axis(x):
        """Turn ``(batch, 4, H, W)`` into ``(batch, 2, H, W, 2)``.

        Channel ``2 * g + p`` ends up at ``[:, g, :, :, p]``. The channel axis
        is first split into ``(2, 2)`` and only then moved to the end with
        ``permute``; a bare ``reshape`` to the target shape would instead fill
        the trailing axis with neighbouring grid cells of a single channel.
        """
        batch_size, _, height, width = x.shape
        paired = x.reshape(batch_size, 2, 2, height, width)
        return paired.permute(0, 1, 3, 4, 2).contiguous()

    def forward(self, batch):
        """Compute the ``(t2m, tp)`` outputs for one batch.

        Args:
            batch: Mapping that holds the input tensor under the
                ``'features_features'`` key.

        Returns:
            Tuple ``(x_t2m, x_tp)``, each of shape
            ``(batch, 2, height, width, 2)`` where ``height`` and ``width`` are
            taken from the backbone output.
            NOTE(review): axes are presumably (batch, lead time, lat, lon,
            distribution parameter) -- confirm with the loss that consumes them.
        """
        x = batch['features_features']
        x = torch.transpose(x, -1, 1)  # Swap channels and time dim.
        self._trace('before projection', x)

        x = self.projection(x)
        x = x.mean(-1)  # Remove time dimension (which was flattened by the projection).
        self._trace('after projection', x)

        # Call the module rather than .forward() so registered hooks still run.
        x = self.pretrained(x)['out']
        self._trace('after pretrained', x)

        x_t2m = self._pairs_to_last_axis(x[:, :4])
        x_tp = self._pairs_to_last_axis(x[:, 4:])
        self._trace('x_t2m', x_t2m)

        return x_t2m, x_tp
        

In [None]:
# Wrap the hub model; the projection layer is told to expect 17 input features.
m = PretrainedModelWrapper(
    pretrained_model=pretrained_model,
    in_features=17,
)

In [None]:
# Build the model described by the experiment config.
# NOTE(review): this rebinds `m`, so the PretrainedModelWrapper instance
# created in the previous cell is discarded and every later cell exercises the
# config-defined model instead. Skip this cell (or use a different name) to
# test the wrapper defined in this notebook.
m = hydra.utils.instantiate(cfg.experiment.model)

In [None]:
# Forward pass on the sample batch; the model is expected to return a
# (t2m, tp) pair.
t2m, tp = m(b)

In [None]:
# Sanity-check the shape of the t2m output.
t2m.shape

In [None]:
# Show the [0, 0, 0] slice of the t2m output as an image; the tensor is
# detached from the autograd graph before conversion to NumPy.
plt.imshow(t2m.detach()[0, 0, 0].numpy())

In [None]:
# Print the name of every parameter registered on the model.
for param_name, _param in m.named_parameters():
    print(param_name)