In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to source files on disk are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn.functional as F

import sys
sys.path.append('../dgmr_deterministic')
import dgmr

In [3]:
# Use the CUDA device when PyTorch reports one as available; otherwise fall
# back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Bare expression so the notebook displays the selected device.
DEVICE

device(type='cuda')

In [6]:
FORECAST = 24     # forwarded to the sampler as forecast_steps
INPUT_STEPS = 4   # forwarded to the multi-step stack as num_context_steps

# Channel widths, named once because each value is handed to two components
# below (a conditioning stack's output and the sampler's matching input).
# NOTE(review): presumably each pair has to agree -- confirm against dgmr.
CONTEXT_CHANNELS = 192
LATENT_CHANNELS = 256

# Conditioning stack configured for INPUT_STEPS context steps.
input_ccs = dgmr.common.ContextConditioningStack(
    num_context_steps=INPUT_STEPS,
    input_channels=1,
    output_channels=CONTEXT_CHANNELS,
    conv_type='standard',
)

# Second conditioning stack, configured for a single context step.
last_ccs = dgmr.common.ContextConditioningStack(
    num_context_steps=1,
    input_channels=1,
    output_channels=LATENT_CHANNELS,
    conv_type='standard',
)

sampler = dgmr.generators.Sampler(
    forecast_steps=FORECAST,
    context_channels=CONTEXT_CHANNELS,
    latent_channels=LATENT_CHANNELS,
)

# Assemble the generator from the three parts and move it to the chosen device.
model = dgmr.generators.Generator(input_ccs, last_ccs, sampler).to(DEVICE)

In [7]:
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Walks ``model.parameters()`` and adds up ``numel()`` for every parameter
    whose ``requires_grad`` flag is set; parameters without it are ignored.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report how many trainable parameters `model` has.
print(f'The model has {count_parameters(model)} trainable parameters')

The model has 13945136 trainable parameters


In [8]:
# Random stand-in inputs for exercising the model. They are allocated directly
# on DEVICE rather than being created on the CPU and copied over afterwards.
# Both are 5-D with size 1 on dims 0 and 2 and 128x128 trailing dims; x_input
# has INPUT_STEPS entries along dim 1, x_last has exactly one.
# NOTE(review): dims look like (batch, time, channel, height, width) -- confirm against dgmr.
x_input = torch.rand((1, INPUT_STEPS, 1, 128, 128), device=DEVICE)
x_last = torch.rand((1, 1, 1, 128, 128), device=DEVICE)

In [9]:
# Forward pass on the two inputs. It is not wrapped in torch.no_grad(), since
# loss.backward() is later called on a loss computed from `out`.
out = model(x_input, x_last)



In [10]:
# Display the output shape; it is compared against a target of shape
# (1, FORECAST, 1, 128, 128) when the loss is computed.
out.shape

torch.Size([1, 24, 1, 128, 128])

In [11]:
# Random stand-in target, allocated directly on DEVICE instead of being
# created on the CPU and copied over.
y = torch.rand((1, FORECAST, 1, 128, 128), device=DEVICE)

# F.mse_loss only emits a warning and broadcasts when the two shapes differ,
# which would silently produce a meaningless loss; fail loudly instead.
assert out.shape == y.shape, (
    f'shape mismatch: out {tuple(out.shape)} vs y {tuple(y.shape)}'
)

# The documented signature is mse_loss(input, target): prediction first,
# target second. The value is the same in either order, but this matches the
# convention readers expect.
loss = F.mse_loss(out, y)

In [12]:
# Display the loss tensor (F.mse_loss reduces to a single mean value by default).
loss

tensor(0.7657, device='cuda:0', grad_fn=<MseLossBackward0>)

In [13]:
# Backpropagate from the loss through the graph recorded during the forward pass.
loss.backward()