In [1]:
import torch
import torch.nn.functional as F
from torchinfo import summary
import warnings
warnings.filterwarnings("ignore")

In [2]:
from metnet import MetNet2
import torch
import torch.nn.functional as F

model = MetNet2(
        forecast_steps=24,
        input_size=128,
        input_channels=3,
        sat_channels=1,
        num_input_timesteps=12,
        upsampler_channels=64,
        lstm_channels=32,
        num_context_blocks=2,
        encoder_channels=64,
        upsampler_channel=128,
        lead_time_features=128,
        output_channels=1,
        )

# summary(model, (2, 12, 3, 128, 128))

# MetNet expects original HxW to be 4x the input size
x = torch.randn((2, 12, 3, 128, 128))
out = []
for lead_time in range(3):
        out.append(model(x, lead_time))
out = torch.stack(out, dim=1)
out.shape
# y = torch.rand((2,8,12,64,64))
# F.mse_loss(out, y).backward()

torch.Size([2, 3, 1, 128, 128])

In [10]:
summary(model, (2, 12, 3, 128, 128))

Layer (type:depth-idx)                        Output Shape              Param #
MetNet2                                       --                        --
├─ConditionWithTimeMetNet2: 1                 --                        --
│    └─ModuleList: 2-1                        --                        --
├─ConvLSTM: 1                                 --                        --
│    └─ModuleList: 2-2                        --                        --
├─TimeDistributed: 1                          --                        --
│    └─DownSampler: 2                         --                        --
│    │    └─Sequential: 3-1                   --                        9,303
├─ModuleList: 1-1                             --                        --
├─ModuleList: 1-2                             --                        --
├─ModuleList: 1-3                             --                        --
├─TimeDistributed: 1-4                        [2, 12, 3, 32, 32]        --
│    └─DownSample

In [3]:
from metnet import MetNet

model = MetNet(
        hidden_dim=24,
        forecast_steps=24,
        input_channels=3,
        output_channels=1,
        sat_channels=1,
        input_size=128,
        )

x = torch.randn((4, 12, 3, 128, 128))

model(x, 0).shape

torch.Size([4, 1, 64, 64])

In [4]:
summary(model, (4, 12, 3, 128, 128))

Layer (type:depth-idx)                        Output Shape              Param #
MetNet                                        --                        --
├─TemporalEncoder: 1                          --                        --
│    └─ConvGRU: 2                             --                        --
│    │    └─ModuleList: 3-1                   --                        57,120
│    │    └─ModuleList: 3-2                   --                        --
├─Sequential: 1                               --                        --
│    └─AxialAttention: 2                      --                        --
│    │    └─ModuleList: 3-3                   --                        4,656
├─DownSample: 1-1                             [4, 12, 3, 64, 64]        --
│    └─Sequential: 2                          --                        --
│    │    └─MaxPool2d: 3-4                    [4, 36, 64, 64]           --
├─ConditionTime: 1-2                          [4, 12, 27, 64, 64]       --
├─TimeDistrib

In [6]:
out = []

for lead_time in range(3):
    output = model(x, lead_time)
    output = torch.nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, padding=1, output_padding=1)(output)
    out.append(output)
out = torch.stack(out, dim=1)
out.shape

torch.Size([2, 3, 1, 63, 63])

In [4]:
out = torch.reshape(out, (2, 3, 64, 64, 1))
out

tensor([[[[[0.1498],
           [0.1448],
           [0.1703],
           ...,
           [0.1835],
           [0.1599],
           [0.1651]],

          [[0.1548],
           [0.1446],
           [0.1713],
           ...,
           [0.1929],
           [0.1623],
           [0.1666]],

          [[0.1420],
           [0.1366],
           [0.1646],
           ...,
           [0.1743],
           [0.1501],
           [0.1565]],

          ...,

          [[0.1492],
           [0.1447],
           [0.1682],
           ...,
           [0.1814],
           [0.1565],
           [0.1626]],

          [[0.1533],
           [0.1464],
           [0.1692],
           ...,
           [0.1836],
           [0.1561],
           [0.1656]],

          [[0.1351],
           [0.1290],
           [0.1548],
           ...,
           [0.1690],
           [0.1440],
           [0.1494]]],


         [[[0.1679],
           [0.1429],
           [0.1956],
           ...,
           [0.1868],
           [0.1615

In [6]:
torch.min(out)

tensor(0.0999, grad_fn=<MinBackward1>)

In [4]:
a = torch.randn((2, 1, 16, 16))
a.shape

torch.Size([2, 1, 16, 16])

In [10]:
torch.nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, padding=1, output_padding=1)(a)

False

In [16]:
import numpy as np

a1 = np.arange(128*128).reshape(128, 128)
a2 = np.arange(128*128).reshape(128, 128)
a = np.stack([a1, a2])
a = np.expand_dims(a, 0)  # (1, 2, _, _)
b = np.repeat(a, repeats=12, axis=0)
b.shape

(12, 2, 128, 128)

In [17]:
c = np.ones((12, 1, 128, 128))

In [18]:
np.concatenate([c, b], axis=1).shape

(12, 3, 128, 128)