In [1]:
import os

import numpy as np
import torch
from matplotlib import pyplot as plt
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

from fit.datamodules.tomo_rec import MNISTTomoFourierTargetDataModule
from fit.modules import TRecTransformerModule
from fit.utils import convert2FC, fft_interpolate
from fit.utils.tomo_utils import get_proj_coords, get_img_coords

In [2]:
# Fix the global random seed through Lightning so runs are reproducible
# (the Trainer below additionally requests deterministic=True).
SEED = 22122020
seed_everything(SEED)

22122020

In [3]:
# Location of the MNIST data. Set the MNIST_ROOT environment variable to point
# elsewhere instead of editing a user-specific absolute path in the notebook;
# the original path is kept as the fallback so existing runs are unaffected.
MNIST_ROOT = os.environ.get('MNIST_ROOT', '/home/tibuch/Data/mnist/')

dm = MNISTTomoFourierTargetDataModule(root_dir=MNIST_ROOT, batch_size=16, num_angles=15)
dm.setup()

In [4]:
# Build the train / validation / test loaders from the prepared data module
# (evaluated left to right, i.e. in the same order as three separate calls).
train_dl, val_dl, test_dl = (
    dm.train_dataloader(),
    dm.val_dataloader(),
    dm.test_dataloader(),
)

In [5]:
# Coordinates of the projection samples (derived from the ray transform's
# acquisition angles) and of the target image grid, both sized by dm.IMG_SHAPE.
acquisition_angles = dm.gt_ds.get_ray_trafo().geometry.angles
proj_xcoords, proj_ycoords = get_proj_coords(angles=acquisition_angles, img_shape=dm.IMG_SHAPE)
target_xcoords, target_ycoords = get_img_coords(img_shape=dm.IMG_SHAPE, endpoint=False)

In [6]:
# Transformer width and number of attention heads. The per-head query size is
# derived from them so the three values cannot drift apart when one is changed
# (previously d_query was the hand-written literal 128//4).
D_MODEL = 128
N_HEADS = 4

model = TRecTransformerModule(
    d_model=D_MODEL,
    # sampling coordinates computed in the previous cell
    y_coords_proj=proj_ycoords, x_coords_proj=proj_xcoords,
    y_coords_img=target_ycoords, x_coords_img=target_xcoords,
    angles=dm.gt_ds.get_ray_trafo().geometry.angles,
    img_shape=dm.IMG_SHAPE,
    # optimiser settings
    lr=0.0001, weight_decay=0.01,
    # presumably the fraction of max_epochs after which the loss is switched;
    # the fit log reports "Switched to real loss" at epoch 2/4 with max_epochs=5 -- TODO confirm
    loss_switch=0.4,
    attention_type='linear',
    n_layers=4,
    n_heads=N_HEADS,
    d_query=D_MODEL // N_HEADS,
    dropout=0.1, attention_dropout=0.1,
)

In [7]:
# Keep the single best checkpoint by lowest 'Train/avg_val_loss' (the metric
# name must match what the module logs) plus the most recent one.
# `filepath=None` was dropped because it is the default value.
best_val_checkpoint = ModelCheckpoint(
    save_top_k=1,
    verbose=False,
    save_last=True,
    monitor='Train/avg_val_loss',
    mode='min',
    prefix='best_val_loss_',
)

# Pass the checkpoint via `callbacks=[...]`: handing a ModelCheckpoint instance to
# `checkpoint_callback=` is deprecated in Lightning 1.1 and later removed.
trainer = Trainer(
    max_epochs=5,
    gpus=1,
    limit_train_batches=100,  # at most 100 training batches per epoch
    limit_val_batches=10,     # at most 10 validation batches per validation run
    callbacks=[best_val_checkpoint],
    deterministic=True,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [8]:
# Display the structure of the wrapped network (held by the Lightning module as `.trec`).
model.trec

TRecTransformer(
  (fourier_coefficient_embedding): Linear(in_features=2, out_features=64, bias=True)
  (pos_embedding_input_projections): PositionalEncoding2D(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (attention): AttentionLayer(
          (inner_attention): LinearAttention(
            (feature_map): ActivationFunctionFeatureMap()
          )
          (query_projection): Linear(in_features=128, out_features=128, bias=True)
          (key_projection): Linear(in_features=128, out_features=128, bias=True)
          (value_projection): Linear(in_features=128, out_features=128, bias=True)
          (out_projection): Linear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_af

In [9]:
# Train on `train_dl`, validating on `val_dl`; the trailing `;` suppresses the
# display of fit's return value.
trainer.fit(model, train_dl, val_dl);


  | Name | Type            | Params
-----------------------------------------
0 | trec | TRecTransformer | 1.9 M 
-----------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at  /opt/conda/conda-bld/pytorch_1607370128159/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2/4: Switched to real loss.


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Saving latest checkpoint...



