In [1]:
# Automatically reload edited project modules (e.g. src.*) before each cell runs,
# so changes to the model/datamodule code are picked up without restarting.
%load_ext autoreload
%autoreload 2

In [2]:
from src.dm import DataModule

# A small batch is enough: this DataModule is only used to pull one batch
# and look at its tensors.
dm = DataModule(batch_size=4)
dm.setup()  # the train/test sizes shown below are printed by this cell

train: 8689
test: 2773


In [3]:
# Pull a single training batch; x unpacks into the two input tensors s1 and s2.
x, labels = next(iter(dm.train_dataloader()))
s1, s2 = x

# Layout is presumably (batch, time, channels, H, W); min/max show s2's value range.
(
	s1.shape,
	s2.shape,
	s2.max(),
	s2.min(),
)

(torch.Size([4, 12, 2, 256, 256]),
 torch.Size([4, 12, 3, 256, 256]),
 tensor(1.),
 tensor(0.))

In [4]:
from src.models.unet_ltae import UNetLTAE
import torch

# Forward-pass shape check only, so no optimizer settings are given;
# 'pretrained': None presumably skips loading pretrained encoder weights.
hparams = {
	'encoder': 'resnet18',
	'pretrained': None,
	'in_channels_s1': 2,
	'in_channels_s2': 3,
	'n_head': 16
}

model = UNetLTAE(hparams).cuda()

x = (s1.cuda(), s2.cuda())
# Run without autograd: otherwise `outputs` keeps the full activation graph
# alive on the GPU for as long as it stays in the notebook namespace
# (including while the training cell below also uses the GPU).
with torch.no_grad():
	outputs = model(x)

outputs.shape

torch.Size([16384, 1, 24]) torch.Size([16384, 24, 16])
torch.Size([65536, 1, 24]) torch.Size([65536, 24, 8])
torch.Size([262144, 1, 24]) torch.Size([262144, 24, 4])
torch.Size([1048576, 1, 24]) torch.Size([1048576, 24, 4])


torch.Size([4, 256, 256])

In [5]:
import pytorch_lightning as pl

# Fresh DataModule for training; Trainer.fit calls its setup itself
# (the split sizes are printed again below).
dm = DataModule(batch_size=4)

hparams = {
	'encoder': 'resnet18',
	'pretrained': 'imagenet',
	'in_channels_s1': 2,
	'in_channels_s2': 3,
	'optimizer': 'Adam',
	'n_head': 16,
	'optimizer_params': {
		'lr': 1e-3
	},
}

module = UNetLTAE(hparams)

# Sanity check: repeatedly fit a single batch (overfit_batches=1). If model,
# loss and optimizer are wired correctly, the training loss is expected to
# approach zero.
trainer = pl.Trainer(
	accelerator='gpu',  # replaces `gpus=1` (deprecated in PL 1.7, removed in 2.0)
	devices=1,
	precision=16,
	overfit_batches=1,
	max_epochs=300,
	logger=False,  # False is the documented way to disable logging; None is not a documented value
	enable_checkpointing=False,
)

trainer.fit(module, dm)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name              | Type             | Params
-------------------------------------------------------
0 | encoder1          | ResNetEncoder    | 11.2 M
1 | encoder2          | ResNetEncoder    | 11.2 M
2 | decoder           | UnetDecoder      | 3.2 M 
3 | segmentation_head | SegmentationHead | 145   
4 | ltae              | LTAE             | 1.3 M 
-------------------------------------------------------
26.7 M    Trainable params
25.6 K    Non-trainable params
26.8 M    Total params
53.522    Total estimated model params size (MB)


train: 8689
test: 2773


  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

In [4]:
from src.models.ltae import LTAE

# Stand-alone LTAE sized for the lowest-level features: 512 input channels and
# sequences of up to 24 steps. return_att=True makes the forward pass return
# (output, attention) so the attention weights can be inspected below.
# n_head is left at its default (att.shape below shows 16 heads).
ltae = LTAE(
	in_channels=512,
	d_model=2 * 512,
	n_neurons=[2 * 512, 512],
	len_max_seq=24,
	return_att=True,
).cuda()

In [12]:
import torch

# Seed the RNG so the synthetic features (and the LTAE outputs/attention
# computed from them below) are reproducible across kernel restarts.
torch.manual_seed(0)

# Synthetic stand-in for a feature map at the lowest level, laid out as
# (batch, time, channels, height, width); time=24 and channels=512 match the
# LTAE's len_max_seq and in_channels.
x = torch.randn(4, 24, 512, 8, 8)
x.shape

torch.Size([4, 24, 512, 8, 8])

In [13]:
# Apply one shared LTAE to every pixel: fold the spatial dims into the batch
# so each pixel becomes its own length-t sequence of c-dim features.
from einops import rearrange

# Guarded so the cell is idempotent: after the first run x is already folded
# (3-D), and re-applying the 5-D pattern would raise.
if x.ndim == 5:
	x = rearrange(x, 'b t c h w -> (b h w) t c')
x.shape

torch.Size([256, 24, 512])

In [15]:
# Inference only: no_grad keeps out/att free of autograd history, so they can be
# inspected or converted (e.g. .cpu().numpy()) without an extra .detach().
with torch.no_grad():
	out, att = ltae(x.cuda())

In [18]:
# Per-pixel LTAE outputs vs. the same values unfolded back into a (batch, c, h, w) map.
(
	out.shape,
	rearrange(out, '(batch h w) c -> batch c h w', h=8, w=8).shape,
)

(torch.Size([256, 512]), torch.Size([4, 512, 8, 8]))

In [19]:
att.size()  # (n_head, b*h*w pixels, t): attention weights over the t time steps, per head and pixel

torch.Size([16, 256, 24])

In [20]:
# Restore the spatial layout of the attention weights: (head, b, t, h, w).
rearrange(att, 'head (b h w) t -> head b t h w', w=8, h=8).shape

torch.Size([16, 4, 24, 8, 8])