<a href="https://colab.research.google.com/github/ivlucky/freelance_steel_defects/blob/master/2_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Magic

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

# Installation

In [2]:
!pip install xmltodict==0.12.0 --quiet
!pip install pytorch_lightning==1.4.9 --quiet
!pip install segmentation-models-pytorch==0.2.0 --quiet
!pip install albumentations==1.1.0 --quiet
!pip install tensorboard==2.7.0 --quiet # for monitoring training
# for tpu training. See https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html
#!pip install cloud-tpu-client https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

[K     |████████████████████████████████| 925 kB 5.4 MB/s 
[K     |████████████████████████████████| 829 kB 33.7 MB/s 
[K     |████████████████████████████████| 282 kB 44.3 MB/s 
[K     |████████████████████████████████| 125 kB 48.1 MB/s 
[K     |████████████████████████████████| 596 kB 47.2 MB/s 
[K     |████████████████████████████████| 1.3 MB 44.1 MB/s 
[K     |████████████████████████████████| 271 kB 47.5 MB/s 
[K     |████████████████████████████████| 160 kB 47.7 MB/s 
[?25h  Building wheel for future (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 87 kB 3.1 MB/s 
[K     |████████████████████████████████| 376 kB 39.2 MB/s 
[K     |████████████████████████████████| 58 kB 5.8 MB/s 
[?25h  Building wheel for efficientnet-pytorch (setup.py) ... [?25l[?25hdone
  Building wheel for pretrainedmodels (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 102 kB 5.5 MB/s 
[K     |████████████████████████████████| 47.6 MB 36 kB/s

# Imports

In [3]:
import os
import xmltodict
import numpy as np
from typing import Dict
from PIL import Image
from datetime import date
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pytorch_lightning as pl
from pytorch_lightning import Trainer, callbacks
import segmentation_models_pytorch as smp
import albumentations as albu

from sklearn.model_selection import train_test_split
import itertools

# Mount drive

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Global variables

In [5]:
DATAFOLDER = "/content/drive/MyDrive/data/Project_steel_plate_defect_detection_dataset"
INPUTFOLDER = os.path.join(DATAFOLDER, "input")
ANNOFOLDER = os.path.join(INPUTFOLDER, "anno")
IMAGESFOLDER = os.path.join(INPUTFOLDER, "images")
AUGMENTEDFOLDER = os.path.join(DATAFOLDER, "augmented")
PROCESSEDFOLDER = os.path.join(DATAFOLDER, "processed")
MASKFOLDER = os.path.join(PROCESSEDFOLDER, "masks")
SPLITFOLDER = os.path.join(DATAFOLDER, "split")
OUTPUTFOLDER = os.path.join(DATAFOLDER, "output")
LOGFOLDER = os.path.join(OUTPUTFOLDER, "logs")
MODELFOLDER = os.path.join(OUTPUTFOLDER, "models")


RANDOM_STATE = 42
VALSIZE = 0.2
TESTSIZE = 0.1
THRESHOLD = 0.5

ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'
NUMCLASSES = 1
INPUTCHANNELS = 3
# Note: don't use activation in model level because of process stability
ACTIVATION = None#'sigmoid' # could be None for logits or 'softmax2d' for multicalss segmentation
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
DEVICE

device(type='cuda', index=0)

# Functions

In [6]:
def get_xmldict(xmlfile):

    with open(xmlfile, 'r') as fp:

        xmlcontent = fp.read()
        xmldict = xmltodict.parse(xmlcontent)

    return xmldict

def set_bndbox(mask, bndbox):

    xmin, xmax = int(bndbox['xmin']), int(bndbox['xmax'])
    ymin, ymax = int(bndbox['ymin']), int(bndbox['ymax'])

    mask[xmin:xmax, ymin:ymax] = 1

def set_allbndbox(mask, objects):

    for obj in objects:
        set_bndbox(mask, obj['bndbox'])

def get_mask(xmlfile):

    xmldict = get_xmldict(xmlfile)

    mask = np.zeros((200, 200))
    if isinstance(xmldict['annotation']['object'], Dict):
        set_bndbox(mask, xmldict['annotation']['object']['bndbox'])
    else:
        set_allbndbox(mask, xmldict['annotation']['object'])

    return mask.T

def plot_image(img):

  print(img.shape)
  fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(14, 14))
  ax.imshow(img)
  ax.axis('off')
  return fig, ax

def show_image(imgfile):

  img = get_image(imgfile)
  return plot_image(img)

def plot_mask(mask, fig, ax):

  ax.imshow(mask, alpha=0.3)

def show_mask(xmlfile, fig, ax):

  mask = get_mask(xmlfile)
  plot_mask(mask, fig, ax)

def show_obj(object_val, datafolder='./'):

    imgfile = os.path.join(datafolder, 'images', object_val + '.jpg')
    xmlfile = os.path.join(datafolder, 'anno', object_val + '.xml')

    fig, ax = show_image(imgfile)
    show_mask(xmlfile, fig, ax)

def sets_describe(set_1, set_2):

    print(f"set_1 len: {len(set_1)}")
    print(f"set_2 len: {len(set_2)}")
    print(f"set_1 - set_2 len: {len(set_1 - set_2)}")
    print(f"set_2 - set_1 len: {len(set_2 - set_1)}")
    print(f"symmetric diff len: {len(set_2.symmetric_difference(set_1))}")
    print(f"intersection len: {len(set_1.intersection(set_2))}")
    print(f"union len: {len(set_1.union(set_2))}")

def get_image(imgfile):

    return mpimg.imread(imgfile)

def save_mask(maskfile, mask):

  np.save(maskfile, mask)

def load_mask(maskfile):

  return np.load(maskfile)

# Classes

In [7]:
class SegDataset(torch.utils.data.Dataset):
  """Initial Dataset for segmentation images.
  
  Args:
  
      obj_list (list): list of object for this dataset
      images_dir (str): path to images folder
      masks_dir (str): path to segmentation masks folder
  
  """
  
  def __init__(self, obj_list, img_dir, mask_dir, transforms=None):

    self.obj_list = obj_list
    self.img_dir = img_dir
    self.mask_dir = mask_dir
    self.transforms = transforms
  
  def __getitem__(self, i):
      
    # read data
    if i > len(self.obj_list):
        raise KeyError('i not in self.obj_list')
    if i < 0:
        raise KeyError('i is negative')

    obj = self.obj_list[i]
    imgfile = os.path.join(self.img_dir, '.'.join([obj, 'jpg']))
    maskfile = os.path.join(self.mask_dir, '.'.join([obj, 'npy']))
    
    image = get_image(imgfile)
    mask = load_mask(maskfile)

    if self.transforms:
      transformed = self.transforms(image=image, mask=mask)
      image = transformed['image']
      mask = transformed['mask']

    mask = np.expand_dims(mask, 0)
        
    return image, mask, obj
      
  def __len__(self):
    return len(self.obj_list)

In [8]:
class SegDataModule(pl.LightningDataModule):
  """Initial DataModule for segmentation images.
    
    Args:
    
        anno_dir (str): dir with annotations
        img_dir: (str): dir with images
        mask_dir (str): dir with mask
        val_size (float): part for 
        test_size=0.1,
        batch_size=32,
        random_state=42
    
  
  """
  def __init__(self, 
               anno_dir: str = "./anno",
               img_dir: str = "./images",
               mask_dir: str = "./masks",
               transforms=None,
               val_size=0.2,
               test_size=0.1,
               batch_size=32,
               random_state=42):

    super().__init__()
    self.anno_dir = anno_dir
    self.img_dir = img_dir
    self.mask_dir = mask_dir
    self.transforms = transforms
    self.val_size = val_size
    self.test_size = test_size
    self.batch_size = batch_size
    self.random_state = random_state

  def prepare_data(self):

    jpgfiles = sorted(os.listdir(self.anno_dir))
    xmlfiles = sorted(os.listdir(self.img_dir))

    set_1 = set([f.split('.')[0] for f in jpgfiles])
    set_2 = set([f.split('.')[0] for f in xmlfiles])
    common_objects = set_1.intersection(set_2)

    os.makedirs(self.mask_dir, exist_ok=True)

    for obj in tqdm(common_objects):

      xmlfile = os.path.join(self.anno_dir, ".".join([obj, "xml"]))
      mask = get_mask(xmlfile)
      maskfile = os.path.join(self.img_dir, ".".join([obj, "npy"]))
      save_mask(maskfile, mask)

  def setup(self, stage=None):

    jpgfiles = sorted(os.listdir(self.anno_dir))
    xmlfiles = sorted(os.listdir(self.img_dir))

    set_1 = set([f.split('.')[0] for f in jpgfiles])
    set_2 = set([f.split('.')[0] for f in xmlfiles])
    common_objects = set_1.intersection(set_2)
    defect_types = set([obj.split('_')[0] for obj in common_objects])

    # split data per defects
    def split(defect_type):

      defect_objects = [obj for obj in common_objects if defect_type in obj]

      nontrain_part = 1 - (self.val_size+self.test_size)
      train, val = train_test_split(defect_objects,
                                    test_size=nontrain_part,
                                    random_state=self.random_state)
      
      val, test = train_test_split(val, 
                                   test_size=self.test_size/nontrain_part,
                                   random_state=self.random_state)
      
      return train, val, test

    self.train = []
    self.val = []
    self.test = []

    for defect_type in defect_types:
      train, val, test = split(defect_type)

      self.train.append(train)
      self.val.append(val)
      self.test.append(test)

    self.train = list(itertools.chain.from_iterable(self.train))
    self.traindataset = SegDataset(self.train, self.img_dir, self.mask_dir,
                                   transforms=self.transforms)

    self.val = list(itertools.chain.from_iterable(self.val))
    self.valdataset = SegDataset(self.val, self.img_dir, self.mask_dir,
                                 transforms=self.transforms)

    self.test = list(itertools.chain.from_iterable(self.test))
    self.testdataset = SegDataset(self.test, self.img_dir, self.mask_dir,
                                  transforms=self.transforms)

    self.alldataset = SegDataset(self.train+self.val+self.test,
                                 self.img_dir, self.mask_dir,
                                 transforms=self.transforms)

  def train_dataloader(self):
    return DataLoader(self.traindataset, batch_size=self.batch_size)

  def val_dataloader(self):
    return DataLoader(self.valdataset, batch_size=self.batch_size)

  def test_dataloader(self):
    return DataLoader(self.testdataset, batch_size=self.batch_size)

  def predict_dataloader(self):
    return DataLoader(self.alldataset, batch_size=1)


In [9]:
class PlSegModel(pl.LightningModule):
  """
  Some code is from https://github.com/Shreeyak/pytorch-lightning-segmentation-template
  """

  def __init__(self, model, loss, 
      train_metric, val_metric, test_metric,
      sample=None,
    ):

    super().__init__()
    self.model = model

    self.loss = loss

    self.train_metric = train_metric
    self.val_metric = val_metric
    self.test_metric = test_metric
    self.sample = sample
    self.monitor_imgs = self.sample

  def forward(self, x):
    """
      In lightning, forward defines the prediction/inference actions.
      This method can be called elsewhere in the LightningModule with: `outputs = self(inputs)`.
    """
      
    output_mask = self.model(x.permute(0,3,1,2))
    return output_mask
  
  def training_step(self, batch, batch_idx):
    """
      Defines the train loop. It is independent of forward().
      Don’t use any cuda or .to(device) calls in the code. 
      PL will move the tensors to the correct device.
    """

    inputs, mask, objects = batch
    outputs = self(inputs)

    pred = outputs
    target = mask
    loss = self.loss(pred, target.type_as(pred))
    self.train_metric(pred, target)
    
    self.log(f'train_metric', self.train_metric, 
             on_step=True, on_epoch=True, 
             prog_bar=False, logger=True, sync_dist=True)
    self.log(f'train_loss', loss, 
             on_step=True, on_epoch=True, 
             prog_bar=False, logger=True, sync_dist=True)

    return loss
  
  def validation_step(self, batch, batch_idx):
      
    inputs, mask, objects = batch
    outputs = self(inputs)
      
    pred = outputs
    target = mask
      
    loss = self.loss(pred, target.type_as(pred))
    self.val_metric(pred, target)

    self.log(f'val_metric', self.val_metric,
              on_step=True, on_epoch=True, 
              prog_bar=False, logger=True, sync_dist=True)
    self.log(f'val_loss', loss,
              on_step=True, on_epoch=True,
              prog_bar=False, logger=True, sync_dist=True)
  
  def test_step(self, batch, batch_idx):
      
    inputs, mask, objects = batch
    outputs = self(inputs)

    pred = outputs
    target = mask
    loss = self.loss(pred, target.type_as(pred))
    self.test_metric(pred, target)

    self.log(f'test_metric', self.test_metric,
              on_step=True, on_epoch=True,
              prog_bar=False, logger=True, sync_dist=True)
    self.log(f'test_loss', loss,
              on_step=True, on_epoch=True,
              prog_bar=False, logger=True, sync_dist=True)
  
  def configure_optimizers(self):

      optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
      lr_scheduler = ReduceLROnPlateau(optimizer, patience=3, verbose=True)
      
      sched = {'scheduler': lr_scheduler,
                'monitor': f'val_metric'}
      
      return [optimizer], [sched]

# Losses

In [10]:
class FocalLoss(torch.nn.Module):
    def __init__(self, gamma=2, alpha=None,
                 ignore_index=255, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average
        self.ignore_index = ignore_index
        self.CE_loss = torch.nn.BCEWithLogitsLoss(reduction='none', pos_weight=alpha)
        

    def forward(self, output, target):
        
        mask = target != self.ignore_index
        output = output * mask
        target = target * mask
        
        logpt = self.CE_loss(output, target)
        pt = torch.exp(-logpt)
        loss = ((1-pt)**self.gamma) * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()

# Metrics

In [11]:
def soft_dice_score(
    output: torch.Tensor, target: torch.Tensor,
    smooth: float = 1e-7, eps: float = 1e-7, dims=None
) -> torch.Tensor:
    """
    :param output:
    :param target:
    :param smooth:
    :param eps:
    :return:
    Shape:
        - Input: :math:`(N, NC, *)` where :math:`*` means any number
            of additional dimensions
        - Target: :math:`(N, NC, *)`, same shape as the input
        - Output: scalar.
    """
    assert output.size() == target.size()
    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)
        cardinality = torch.sum(output + target, dim=dims)
    else:
        intersection = torch.sum(output * target)
        cardinality = torch.sum(output + target)
    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
    return dice_score

class PLDice(pl.metrics.Metric):
    def __init__(self,  threshold: float = 0.5,
                 smooth: float = 1e-7, ignore_index=255,
                 compute_on_step=False,
                 dist_sync_on_step=False, channel_id=None):
      
        super().__init__(compute_on_step=compute_on_step,
                         dist_sync_on_step=dist_sync_on_step)
        self.threshold = threshold
        self.smooth = smooth
        self.ignore_index = ignore_index
        self.channel_id = channel_id
        
        self.add_state("correct", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor, names=None, imgs=None):
        preds = torch.sigmoid(preds)
        mask = target != self.ignore_index
        preds = preds * mask
        target = target * mask

        pred_bool = preds > self.threshold
        if self.channel_id is None:
            dice = soft_dice_score(pred_bool, target,
                                   smooth=self.smooth, dims=(2,3))
        else:
            i = self.channel_id
            dice = soft_dice_score(pred_bool[:,i:i+1,:,:], target[:,i:i+1,:,:],
                                   smooth=self.smooth, dims=(2,3))
        divisor = preds.shape[1] if self.channel_id is None else 1
        self.correct += dice.sum() / divisor # division into amount of object classes
        self.total += len(dice)

    def compute(self):
        return self.correct.float() / self.total

# Model

In [12]:
# create segmentation model with pretrained encoder
model = smp.Unet(
    encoder_name=ENCODER,        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights=ENCODER_WEIGHTS,     # use `imagenet` pre-trained weights for encoder initialization
    classes=NUMCLASSES,   # model output channels (number of classes in your dataset)
    in_channels=INPUTCHANNELS,
    activation=ACTIVATION
)

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


  0%|          | 0.00/83.3M [00:00<?, ?B/s]

In [13]:
transforms = albu.Compose([
    albu.Normalize(),
    albu.Resize(256, 256)
])

In [14]:
seg_loss = FocalLoss()
seg_train_metric = PLDice(threshold=THRESHOLD, compute_on_step=True)
seg_valid_metric = PLDice(threshold=THRESHOLD, compute_on_step=True)
seg_test_metric = PLDice(threshold=THRESHOLD, compute_on_step=True)

  stream(template_mgs % msg_args)


In [15]:
pl_seg_model = PlSegModel(model, 
                          seg_loss,
                          seg_train_metric,
                          seg_valid_metric,
                          seg_test_metric)

In [16]:
pl_seg_datamodule = SegDataModule(ANNOFOLDER, 
                                  IMAGESFOLDER, 
                                  MASKFOLDER,
                                  transforms=transforms,
                                  val_size=VALSIZE,
                                  test_size=TESTSIZE, 
                                  random_state=RANDOM_STATE)

In [17]:
# pl_seg_datamodule.prepare_data()
# pl_seg_datamodule.setup()

In [18]:
model_type = model.__class__.__name__

today = date.today()
model_creation_date = today.strftime("%d_%m_%Y")
filename =  f'{model_type}_{model_creation_date}' +\
            '_{epoch:02d}_' + f'{seg_loss.__class__.__name__}_' +'{val_loss' + ':.6f}' +\
            f'_{seg_valid_metric.__class__.__name__}_' + '{val_metric' + ':.2f}'

pl_callbacks = [callbacks.ModelCheckpoint(
                            monitor=f'val_metric',
                            dirpath=MODELFOLDER,
                            verbose=True,
                            mode='max',
                            filename=filename
                        ),
                callbacks.EarlyStopping(monitor=f'val_metric',
                              min_delta=0.0,
                              patience=3,
                              verbose=True, 
                              mode='max', 
                              strict=True),
                callbacks.LearningRateMonitor(logging_interval='epoch')
    ]

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


In [19]:
seg_trainer = Trainer(gpus=[0], num_nodes=1, callbacks=pl_callbacks, 
                      auto_lr_find=False,
                      default_root_dir=LOGFOLDER, precision=32,
                      log_every_n_steps=2)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [20]:
seg_trainer.fit(pl_seg_model, pl_seg_datamodule)

  0%|          | 0/480 [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | Unet      | 24.4 M
1 | loss         | FocalLoss | 0     
2 | train_metric | PLDice    | 0     
3 | val_metric   | PLDice    | 0     
4 | test_metric  | PLDice    | 0     
-------------------------------------------
24.4 M    Trainable params
0         Non-trainable params
24.4 M    Total params
97.745    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Metric val_metric improved. New best score: 0.332
Epoch 0, global step 4: val_metric reached 0.33222 (best 0.33222), saving model to "/content/drive/MyDrive/data/Project_steel_plate_defect_detection_dataset/output/models/Unet_28_10_2021_epoch=00_FocalLoss_val_loss=0.183068_PLDice_val_metric=0.33.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Metric val_metric improved by 0.044 >= min_delta = 0.0. New best score: 0.376
Epoch 1, global step 9: val_metric reached 0.37628 (best 0.37628), saving model to "/content/drive/MyDrive/data/Project_steel_plate_defect_detection_dataset/output/models/Unet_28_10_2021_epoch=01_FocalLoss_val_loss=0.154104_PLDice_val_metric=0.38.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Metric val_metric improved by 0.062 >= min_delta = 0.0. New best score: 0.438
Epoch 2, global step 14: val_metric reached 0.43794 (best 0.43794), saving model to "/content/drive/MyDrive/data/Project_steel_plate_defect_detection_dataset/output/models/Unet_28_10_2021_epoch=02_FocalLoss_val_loss=0.140827_PLDice_val_metric=0.44.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Metric val_metric improved by 0.061 >= min_delta = 0.0. New best score: 0.499
Epoch 3, global step 19: val_metric reached 0.49852 (best 0.49852), saving model to "/content/drive/MyDrive/data/Project_steel_plate_defect_detection_dataset/output/models/Unet_28_10_2021_epoch=03_FocalLoss_val_loss=0.126741_PLDice_val_metric=0.50.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Metric val_metric improved by 0.033 >= min_delta = 0.0. New best score: 0.531
Epoch 4, global step 24: val_metric reached 0.53148 (best 0.53148), saving model to "/content/drive/MyDrive/data/Project_steel_plate_defect_detection_dataset/output/models/Unet_28_10_2021_epoch=04_FocalLoss_val_loss=0.113072_PLDice_val_metric=0.53.ckpt" as top 1


Epoch     5: reducing learning rate of group 0 to 1.0000e-05.


Validating: 0it [00:00, ?it/s]

Metric val_metric improved by 0.002 >= min_delta = 0.0. New best score: 0.534
Epoch 5, global step 29: val_metric reached 0.53381 (best 0.53381), saving model to "/content/drive/MyDrive/data/Project_steel_plate_defect_detection_dataset/output/models/Unet_28_10_2021_epoch=05_FocalLoss_val_loss=0.106719_PLDice_val_metric=0.53.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 6, global step 34: val_metric was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 7, global step 39: val_metric was not in top 1


Validating: 0it [00:00, ?it/s]

Monitored metric val_metric did not improve in the last 3 records. Best score: 0.534. Signaling Trainer to stop.
Epoch 8, global step 44: val_metric was not in top 1


Epoch     9: reducing learning rate of group 0 to 1.0000e-06.


In [21]:
result = seg_trainer.test()
print(result)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.10540266335010529,
 'test_loss_epoch': 0.10540266335010529,
 'test_metric': 0.5395669341087341,
 'test_metric_epoch': 0.5395669341087341}
--------------------------------------------------------------------------------
[{'test_metric': 0.5395669341087341, 'test_metric_epoch': 0.5395669341087341, 'test_loss': 0.10540266335010529, 'test_loss_epoch': 0.10540266335010529}]


In [22]:
# img, mask, obj = next(iter(pl_seg_datamodule.predict_dataloader()))
# img.shape, mask.shape, obj

In [23]:
# fig, ax = plot_image(img[0] * 255)
# plot_mask(mask[0,0], fig, ax)

In [24]:
# pred = model(img.permute(0,3,1,2))
# print(pred.shape, masks.shape)
# print(seg_loss(pred, masks.type_as(pred)))
# print(seg_train_metric(pred, masks))

In [25]:
# fig, ax = plot_image(img[0] * 255)
# plot_mask(pred[0,0].detach().numpy(), fig, ax)

In [26]:
# imgs, masks, obj = next(iter(pl_seg_datamodule.train_dataloader()))
# imgs.shape, masks.shape, obj