In [1]:
import torch
import torchvision
import pytorch_lightning as pl
from matplotlib import pyplot as plt

from Models import U_net, PlottingCallback
from Simulation import DataGenerator
import warnings
# Suppress every warning category for the remainder of the kernel session.
# NOTE(review): a blanket "ignore" filter also hides library deprecation
# warnings (e.g. for Trainer arguments scheduled for removal) -- consider
# narrowing it with `category=` / `module=` once the noisy source is known.
warnings.filterwarnings("ignore") 

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# --- Configuration ---------------------------------------------------------
# Everything a reader might want to tune lives here instead of being buried
# as literals in the calls below.
SEED = 42                    # fixed seed for a repeatable run
IMG_SIZE = 200               # side length of the (square) images fed to the U-Net
N_TRAIN = 100                # number of samples requested for the training set
N_VALID = 10                 # number of samples requested for the validation set
BATCH_SIZE = 10
LEARNING_RATE = 1e-3
PROGRESS_REFRESH_RATE = 20   # redraw the progress bar every N batches

# Seed Python's `random`, NumPy and torch in one call so weight initialisation
# and batch shuffling are repeatable across "Restart & Run All".
# NOTE(review): this also makes the generated data repeatable only if
# DataGenerator draws from one of those RNGs -- confirm in Simulation.py.
pl.seed_everything(SEED)

# --- Dataset prep ----------------------------------------------------------
# img_size is set on the class (not the instance) so that both generators
# created below pick up the same value.
DataGenerator.img_size = IMG_SIZE
dt = DataGenerator('train', size=N_TRAIN)
dv = DataGenerator('valid', size=N_VALID)
dt.generate()
dv.generate()

# Shuffle only the training data; validation order stays fixed so the plots
# produced by PlottingCallback are comparable from epoch to epoch.
train_loader = torch.utils.data.DataLoader(dt, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = torch.utils.data.DataLoader(dv, batch_size=BATCH_SIZE, shuffle=False)

# --- Model and training ----------------------------------------------------
model = U_net(img_size=[dt.img_size, dt.img_size], learning_rate=LEARNING_RATE)

# `gpus=` and `progress_bar_refresh_rate=` are deprecated Trainer arguments
# that were removed in later Lightning releases. The replacements below keep
# the same behaviour (all visible GPUs, progress bar refreshed every
# PROGRESS_REFRESH_RATE batches) and keep working after an upgrade.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=-1,
    fast_dev_run=False,
    callbacks=[
        PlottingCallback(dataloader=valid_loader),
        pl.callbacks.TQDMProgressBar(refresh_rate=PROGRESS_REFRESH_RATE),
    ],
)
trainer.fit(model, train_loader, valid_loader)

100%|██████████| 100/100 [00:00<00:00, 269.52it/s]
100%|██████████| 10/10 [00:00<00:00, 263.75it/s]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type       | Params | In sizes                             | Out sizes                           
-----------------------------------------------------------------------------------------------------------------------------
0 | enc_1           | Encoder    | 828    | [1, 1, 200, 200]                     | [[1, 9, 100, 100], [1, 9, 100, 100]]
1 | enc_2           | Encoder    | 4.4 K  | [1, 9, 100, 100]                     | [[1, 18, 50, 50], [1, 18, 50, 50]]  
2 | enc_3           | Encoder    | 17.6 K | [1, 18, 50, 50]                      | [[1, 36, 25, 25], [1, 36, 25, 25]]  
3 | dec_3           | Decoder    | 8.8 K  | [[1, 36, 25, 25], [1, 36, 25, 25]]   | [1, 18, 50, 50]        

Epoch 0: 100%|██████████| 11/11 [00:02<00:00,  4.37it/s, loss=0.63, v_num=10]0
Epoch 10: 100%|██████████| 11/11 [00:00<00:00, 15.09it/s, loss=0.227, v_num=10]10
Epoch 20: 100%|██████████| 11/11 [00:00<00:00, 13.98it/s, loss=0.0365, v_num=10]20
Epoch 30: 100%|██████████| 11/11 [00:00<00:00, 14.67it/s, loss=0.0186, v_num=10]30
Epoch 40: 100%|██████████| 11/11 [00:00<00:00, 15.32it/s, loss=0.0153, v_num=10]40
Epoch 50: 100%|██████████| 11/11 [00:00<00:00, 14.88it/s, loss=0.0128, v_num=10]50
Epoch 60: 100%|██████████| 11/11 [00:00<00:00, 15.47it/s, loss=0.0127, v_num=10]60
Epoch 70: 100%|██████████| 11/11 [00:00<00:00, 14.40it/s, loss=0.01, v_num=10]  70
Epoch 80: 100%|██████████| 11/11 [00:00<00:00, 15.60it/s, loss=0.00841, v_num=10]80
Epoch 90: 100%|██████████| 11/11 [00:00<00:00, 13.82it/s, loss=0.00834, v_num=10]90
Epoch 100: 100%|██████████| 11/11 [00:00<00:00, 14.63it/s, loss=0.00734, v_num=10]100
Epoch 110: 100%|██████████| 11/11 [00:00<00:00, 14.69it/s, loss=0.00748, v_num=10]110
E

Epoch 146:   0%|          | 0/11 [00:15<?, ?it/s, loss=0.00637, v_num=10]