In [1]:
import os
from pathlib import Path

import torch

In [2]:
import pandas as pd
from biomasstry.models import UTAE
from biomasstry.datasets import TemporalSentinel2Dataset

In [3]:
from torch.utils.data import DataLoader

In [4]:
s3_url = "s3://drivendata-competition-biomassters-public-us"
# Local data root. Set BIOMASSTERS_DATA_DIR to point elsewhere instead of editing an
# absolute path in the notebook; the default keeps the original location.
DATA_DIR = Path(os.environ.get("BIOMASSTERS_DATA_DIR", "/notebooks/data"))
metadata_file = DATA_DIR / "metadata_parquet" / "features_metadata_slim.parquet"
metadata_df = pd.read_parquet(metadata_file)
# Unique chip ids of the training split; bracket access instead of attribute-style
# column access (`df.split`), which is fragile for column names that could shadow methods.
chip_ids = metadata_df.loc[metadata_df["split"] == "train", "chip_id"].unique().tolist()
assert chip_ids, f"no 'train' chips found in {metadata_file}"

In [5]:
# Dataset over the training chip ids; data_url points it at the public competition
# S3 bucket defined in s3_url.
ds = TemporalSentinel2Dataset(chip_ids, data_url=s3_url)

In [6]:
# Batch 4 chips at a time (default: no shuffling), just to inspect tensor shapes.
dl = DataLoader(dataset=ds, batch_size=4)

In [7]:
# Grab only the first batch from a fresh iterator over the loader.
batch_iter = iter(dl)
b = next(batch_iter)

In [8]:
# U-TAE built with 10 as its first argument — presumably the number of input channels,
# matching dim 2 of x (torch.Size([4, 5, 10, 256, 256])); confirm against UTAE's signature.
# .eval() returns the module itself. This notebook only runs forward passes to inspect
# shapes, and in training mode any dropout layers would be active and any batch-norm
# layers would update their running statistics on every exploratory call.
model = UTAE(10).eval()

In [9]:
# Unpack the batch: x is the 5-D input (see x.shape below), y the target, and cid
# presumably the chip ids of the 4 samples — confirm against the dataset's __getitem__.
x, y, cid = b

In [10]:
# Input batch shape — presumably (batch, time steps, bands, H, W); confirm with the dataset.
x.size()

torch.Size([4, 5, 10, 256, 256])

In [11]:
# Full forward pass on the 4-sample batch. Only the output shape is needed, so run it
# under no_grad: otherwise autograd keeps every intermediate activation of this large
# (4, 5, 10, 256, 256) input alive for as long as `yb` exists.
with torch.no_grad():
    yb = model(x)

In [12]:
# Model output shape for the whole batch.
yb.shape

torch.Size([4, 1, 256, 256])

In [14]:
# Forward pass on the first sample only. x[:1] keeps the batch dim, exactly like
# x[0].unsqueeze(dim=0), giving a (1, 5, 10, 256, 256) input. no_grad avoids
# retaining autograd activations that are not needed for a shape check.
with torch.no_grad():
    y0 = model(x[:1])  # batch size of 1 to avoid kernel crashing

In [15]:
# Output shape for the single-sample pass.
y0.shape

torch.Size([1, 1, 256, 256])

In [18]:
# An entry of the (batch, dim-1) mask is True when every value across the last three
# dims (presumably bands, H, W) equals model.pad_value. Collapsing those three dims
# into one and reducing once is equivalent to three chained .all(dim=-1) calls.
is_pad = x == model.pad_value
pad_mask = is_pad.flatten(start_dim=-3).all(dim=-1)
print(f"pad mask: {pad_mask.size()}")

pad mask: torch.Size([4, 5])


In [19]:
# Input convolution. smart_forward keeps the leading (4, 5) dims of x and maps the
# 10 input channels to 64 (see the printed shape). no_grad: this is a shape walkthrough,
# and the 256x256 activations here are the largest in the network.
with torch.no_grad():
    out = model.in_conv.smart_forward(x)
print(f"out: {out.shape}")

out: torch.Size([4, 5, 64, 256, 256])


In [20]:
feature_maps = [out]
# SPATIAL ENCODER
# Each down block halves H and W (256 -> 128 -> 64 -> 32 in the printed shapes).
# Results go into `stage_out`, not `out`, so `out` still holds the in_conv output after
# this cell and re-running it rebuilds the same feature_maps. It still breaks if re-run
# after the temporal-encoder or decoder cells, which reassign `out`.
with torch.no_grad():
    for i in range(model.n_stages - 1):
        stage_out = model.down_blocks[i].smart_forward(feature_maps[-1])
        print(f"{i}: {stage_out.size()}")
        feature_maps.append(stage_out)

0: torch.Size([4, 5, 64, 128, 128])
1: torch.Size([4, 5, 64, 64, 64])
2: torch.Size([4, 5, 128, 32, 32])


In [21]:
# One line per encoder stage, finest resolution first.
for fmap in feature_maps:
    print(fmap.size())

torch.Size([4, 5, 64, 256, 256])
torch.Size([4, 5, 64, 128, 128])
torch.Size([4, 5, 64, 64, 64])
torch.Size([4, 5, 128, 32, 32])


In [23]:
# TEMPORAL ENCODER
# Runs only on the deepest feature map and removes dim 1:
# (4, 5, 128, 32, 32) -> (4, 128, 32, 32). `att` is passed to the decoder's
# temporal_aggregator as attn_mask; its leading 16 is presumably the head count — confirm.
with torch.no_grad():
    out, att = model.temporal_encoder(
        feature_maps[-1], batch_positions=None, pad_mask=pad_mask
    )
print(f"Out: {out.size()}, Att: {att.size()}")

Out: torch.Size([4, 128, 32, 32]), Att: torch.Size([16, 4, 5, 32, 32])


In [18]:
# Encoder maps in the order the decoder uses them as skips: feature_maps[-2] down to
# feature_maps[-n_stages], i.e. coarsest first and full resolution last.
for depth in range(2, model.n_stages + 1):
    print(feature_maps[-depth].size())

torch.Size([4, 5, 64, 64, 64])
torch.Size([4, 5, 64, 128, 128])
torch.Size([4, 5, 64, 256, 256])


In [24]:
# Flag checked by the decoder walkthrough below: it only builds the `maps` list when True.
model.return_maps

False

In [25]:
# SPATIAL DECODER
# For each level, temporal_aggregator collapses dim 1 of that encoder map using the
# temporal encoder's `att`, then up_blocks[i] upsamples `out` and combines it with the skip.
# NOTE: this cell reads `out` (temporal-encoder output) and overwrites it with the decoder
# output, so re-run the temporal-encoder cell before re-running this one.
with torch.no_grad():
    if model.return_maps:
        maps = [out]
    for i in range(model.n_stages - 1):
        level = -(i + 2)  # negative index into feature_maps: -2, -3, ...
        skip = model.temporal_aggregator(
            feature_maps[level], pad_mask=pad_mask, attn_mask=att
        )
        out = model.up_blocks[i](out, skip)
        if model.return_maps:
            maps.append(out)
        print(f"Fm({level}): {feature_maps[level].size()}")
        print(f"skip({i}): {skip.size()}")
        print(f"out({i}): {out.size()}")

Fm(-2): torch.Size([4, 5, 64, 64, 64])
skip(0): torch.Size([4, 64, 64, 64])
out(0): torch.Size([4, 64, 64, 64])
Fm(-3): torch.Size([4, 5, 64, 128, 128])
skip(1): torch.Size([4, 64, 128, 128])
out(1): torch.Size([4, 32, 128, 128])
Fm(-4): torch.Size([4, 5, 64, 256, 256])
skip(2): torch.Size([4, 64, 256, 256])
out(2): torch.Size([4, 32, 256, 256])


In [27]:
# Output head: maps the decoder's 32 channels to 1, matching the target's
# [4, 1, 256, 256] shape. no_grad because this is a shape check only.
with torch.no_grad():
    out_c = model.out_conv(out)
print(f"out conv: {out_c.size()}")

out conv: torch.Size([4, 1, 256, 256])


In [20]:
# Target shape, for comparison with the model output.
y.shape

torch.Size([4, 1, 256, 256])