In [19]:
import os
import preprocess
import pytorch_lightning as pl
import torch
import numpy as np
from rasterio.plot import show
from torch.utils.data import DataLoader
from torch.nn import functional as F
from pytorch_lightning.loggers import TensorBoardLogger

In [15]:
class LitUNet(pl.LightningModule):
    """U-Net from torch.hub wrapped as a LightningModule and trained with MSE loss.

    The module builds its own dataset from ``file_pairs`` (via
    ``preprocess.GISDataset``), splits it into train/validation/test subsets and
    exposes the matching dataloaders.

    Args:
        file_pairs: passed unchanged to ``preprocess.GISDataset``; the notebook
            passes a list of ``(image_path, shapefile_path)`` tuples.
        input_num: number of input channels of the U-Net.
        output_num: number of output channels of the U-Net.
        initial_feat: ``init_features`` of the hub U-Net.
        trained: forwarded as ``pretrained`` to ``torch.hub.load``.
        batch_size: batch size used by all three dataloaders.
        num_workers: worker processes used by all three dataloaders.
    """

    def __init__(self, file_pairs, input_num=4, output_num=1, initial_feat=32, trained=False,
                 batch_size=64, num_workers=10):
        super().__init__()
        self.model = torch.hub.load(
            'mateuszbuda/brain-segmentation-pytorch', 'unet',
            in_channels=input_num, out_channels=output_num,
            init_features=initial_feat, pretrained=trained)
        self.file_pairs = file_pairs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.criterion = torch.nn.MSELoss(reduction="mean")

    def forward(self, x):
        """Run the wrapped U-Net on ``x`` and return its output."""
        return self.model(x)

    def prepare_data(self):
        """Build the full dataset and randomly split it into train/val/test subsets.

        Validation and test each receive ``int(0.15 * total)`` samples; the
        training subset receives every remaining sample, so the three sizes
        always sum to ``len(dataset)``.
        """
        all_data = preprocess.GISDataset(self.file_pairs)
        total = len(all_data)
        val = int(total * .15)
        # Everything not reserved for validation/test goes to training; this also
        # absorbs any remainder lost to integer truncation.
        train = total - 2 * val
        self.train_set, self.validate_set, self.test_set = torch.utils.data.random_split(
            all_data, [train, val, val])

    def _make_loader(self, dataset, shuffle=False):
        """Create a DataLoader over ``dataset`` with the module's batch/worker settings."""
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=shuffle)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return self._make_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return self._make_loader(self.validate_set)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return self._make_loader(self.test_set)

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters (lr=1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def cross_entropy_loss(self, logits, labels):
        """Return the cross-entropy between ``logits`` and integer class ``labels``.

        ``F.cross_entropy`` applies log-softmax before the NLL loss, so it is
        correct for raw logits. Because log-softmax is idempotent, inputs that
        are already log-probabilities produce the same value as ``F.nll_loss``.
        This method is not used by the training/validation/test steps below.
        """
        return F.cross_entropy(logits, labels)

    def _batch_loss(self, batch):
        """Compute the MSE criterion between the model output and the batch mask."""
        x = batch['image']
        # Insert a dimension at index 1 on the mask; presumably this gives it the
        # same channel axis as the model output -- TODO confirm against GISDataset.
        y = batch['mask'].unsqueeze(1)
        preds = self.forward(x)
        return self.criterion(preds, y)

    def training_step(self, train_batch, batch_idx):
        """Return the training loss plus a ``'log'`` dict carrying ``train_loss``."""
        loss = self._batch_loss(train_batch)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def test_step(self, batch, batch_idx):
        """Return the loss of one test batch under the key ``'test_loss'``."""
        return {'test_loss': self._batch_loss(batch)}

    def test_epoch_end(self, outputs):
        """Average the per-batch test losses collected in ``outputs``."""
        avg_loss = torch.stack([out['test_loss'] for out in outputs]).mean()
        return {'avg_test_loss': avg_loss, 'log': {'test_loss': avg_loss}}

    def validation_step(self, val_batch, batch_idx):
        """Return the loss of one validation batch under the key ``'val_loss'``."""
        return {'val_loss': self._batch_loss(val_batch)}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses at the end of a validation epoch."""
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'avg_val_loss': avg_loss, 'log': {'val_loss': avg_loss}}

In [16]:
# Common root of the ephemeral-channels data; both 2012 inputs live beneath it.
_EPHEMERAL_ROOT = "/vol/ml/EphemeralStreamData/Ephemeral_Channels"
# 2012 imagery file and the merged 2012 reference shapefile.
twelve_img = _EPHEMERAL_ROOT + "/Imagery/vhr_2012_refl.img"
twelve_shp = _EPHEMERAL_ROOT + "/Reference/reference_2012_merge.shp"

In [17]:
# One (image, shapefile) pair for 2012; LitUNet hands this list to preprocess.GISDataset.
file_pairs_2012 = [(twelve_img, twelve_shp)]
model = LitUNet(file_pairs_2012)

Using cache found in /homes/mzvyagin/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


In [20]:
# TensorBoard logger rooted at the current working directory, experiment name
# 'lightning_logs', pinned to version 1.
log_root = os.getcwd()
mylogger = TensorBoardLogger(
    save_dir=log_root,
    name='lightning_logs',
    version=1,
)

In [23]:
# Trainer settings: run on GPU index 2 for at most 100 epochs, logging through mylogger.
trainer_options = {
    'gpus': [2],
    'max_epochs': 100,
    'logger': mylogger,
}
trainer = pl.Trainer(**trainer_options)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [2]


In [None]:
# Start training; the model supplies its own dataloaders, so no data is passed here.
trainer.fit(model)

  return _prepare_from_string(" ".join(pjargs))

  | Name      | Type    | Params
--------------------------------------
0 | model     | UNet    | 7 M   
1 | criterion | MSELoss | 0     


Epoch 1:  82%|████████▏ | 130/158 [00:57<00:12,  2.28it/s, loss=0.125, v_num=1]
Validating: 0it [00:00, ?it/s][A
Epoch 1:  83%|████████▎ | 131/158 [01:00<00:12,  2.16it/s, loss=0.125, v_num=1]
Epoch 1:  84%|████████▎ | 132/158 [01:00<00:11,  2.18it/s, loss=0.125, v_num=1]
Epoch 1:  84%|████████▍ | 133/158 [01:00<00:11,  2.19it/s, loss=0.125, v_num=1]
Epoch 1:  85%|████████▍ | 134/158 [01:00<00:10,  2.20it/s, loss=0.125, v_num=1]
Epoch 1:  85%|████████▌ | 135/158 [01:01<00:10,  2.21it/s, loss=0.125, v_num=1]
Epoch 1:  86%|████████▌ | 136/158 [01:01<00:09,  2.22it/s, loss=0.125, v_num=1]
Epoch 1:  87%|████████▋ | 137/158 [01:01<00:09,  2.24it/s, loss=0.125, v_num=1]
Epoch 1:  87%|████████▋ | 138/158 [01:01<00:08,  2.25it/s, loss=0.125, v_num=1]
Epoch 1:  88%|████████▊ | 139/158 [01:01<00:08,  2.26it/s, loss=0.125, v_num=1]
Epoch 1:  89%|████████▊ | 140/158 [01:01<00:07,  2.27it/s, loss=0.125, v_num=1]
Epoch 1:  89%|████████▉ | 141/158 [01:01<00:07,  2.28it/s, loss=0.125, v_num=1]
Epoch 