# Preliminaries

In [None]:
%%capture
# Install Weights & Biases (experiment tracking); %%capture hides the pip output.
!pip install wandb

In [None]:
#!wget -q https://www.dropbox.com/s/mk6w88fzn6n9eak/dataset2.zip
#!unzip -q dataset2.zip
#!rm dataset2.zip

In [None]:
import os
# Download and unpack the dataset only when the `dataset2` directory is missing
# (assumes the archive extracts to `dataset2/` — TODO confirm).
# NOTE(review): this cell fetches `dataset2`, but SegmentationData reads from
# `dataset1/...` — confirm which dataset the notebook is meant to use.
if not os.path.exists('dataset2'):
    !wget -q https://www.dropbox.com/s/mk6w88fzn6n9eak/dataset2.zip
    !unzip -q dataset2.zip
    !rm dataset2.zip

!pip install -q torch_snippets pytorch_model_summary

# torch_snippets is star-imported: names such as read, stems, Dataset, DataLoader,
# nn, optim, cv2 and plt used below presumably come from here — TODO confirm.
from torch_snippets import *
from torchvision import transforms
from sklearn.model_selection import train_test_split

from torchvision.models import vgg16_bn

from tqdm import tqdm

# W&B login is done with an API key; when this cell runs it will prompt for the key.
import wandb
wandb.login()

unzip:  cannot find or open dataset1.zip, dataset1.zip.zip or dataset1.zip.ZIP.
rm: cannot remove 'dataset1.zip': No such file or directory
[K     |████████████████████████████████| 54 kB 2.5 MB/s 
[K     |████████████████████████████████| 78 kB 6.6 MB/s 
[K     |████████████████████████████████| 237 kB 78.2 MB/s 
[K     |████████████████████████████████| 58 kB 6.1 MB/s 
[K     |████████████████████████████████| 1.6 MB 72.2 MB/s 
[K     |████████████████████████████████| 174 kB 50.8 MB/s 
[K     |████████████████████████████████| 2.2 MB 58.6 MB/s 
[K     |████████████████████████████████| 51 kB 6.9 MB/s 
[?25h  Building wheel for sklearn (setup.py) ... [?25l[?25hdone
  Building wheel for typing (setup.py) ... [?25l[?25hdone


ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

# Configuration

In [None]:
class config:
    """Notebook-wide settings, read as class attributes (``config.DEVICE`` etc.)."""

    # Run on the GPU whenever one is visible to PyTorch.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # Adam step size used by make_model().
    LEARNING_RATE = 0.001
    # Number of full passes over the training set in run().
    N_EPOCHS = 100

# W&B run config: record the hyperparameters alongside the run so results on the
# dashboard can be traced back to the settings that produced them.
WANDB_CONFIG = {
    '_wandb_kernel': 'neuracort',
    'learning_rate': config.LEARNING_RATE,
    'n_epochs': config.N_EPOCHS,
    'device': config.DEVICE,
}

# Initialise the W&B project. The handle is called `wandb_run` rather than `run`
# because the training cell later defines a function named `run()`, which would
# otherwise silently overwrite (shadow) this Run object.
wandb_run = wandb.init(
    project='semantic_segmentation_unet',
    config=WANDB_CONFIG,
)

[34m[1mwandb[0m: Currently logged in as: [33mnetohd97[0m ([33mdataservices[0m). Use [1m`wandb login --relogin`[0m to force relogin


# Transformations

In [None]:
def get_transforms():
  """Build the input preprocessing pipeline.

  Converts an HWC array to a CHW tensor with ToTensor, then normalises each channel
  with the standard ImageNet mean/std (the statistics the pretrained VGG encoder
  expects). collate_fn divides pixel values by 255 before calling this.
  """
  channel_mean = [0.485, 0.456, 0.406]
  channel_std = [0.229, 0.224, 0.225]
  pipeline = [
      transforms.ToTensor(),
      transforms.Normalize(channel_mean, channel_std),
  ]
  return transforms.Compose(pipeline)

# Dataset Class

In [None]:
class SegmentationData(Dataset):
    """Image / mask pairs for semantic segmentation.

    Files are read from ``{root}/images_prepped_{split}/<stem>.png`` (images, read
    with ``read(..., 1)``) and ``{root}/annotations_prepped_{split}/<stem>.png``
    (masks). Masks are used as CrossEntropyLoss targets, so their pixel values are
    class indices.

    Args:
        split: split suffix of the directory names, e.g. ``'train'`` or ``'test'``.
        root: dataset directory. Defaults to ``'dataset1'``.
            NOTE(review): the download cell fetches ``dataset2`` — confirm which
            directory is intended.
        size: ``(width, height)`` passed to ``cv2.resize`` for images and masks.
            The UNet decoder presumably needs multiples of 32 — TODO confirm.
    """

    def __init__(self, split, root='dataset1', size=(224, 224)):
        self.root = root
        self.split = split
        self.size = size
        self.items = stems(f'{root}/images_prepped_{split}')

    def __len__(self):
        return len(self.items)

    def __getitem__(self, ix):
        """Return the ``(image, mask)`` pair for item ``ix``, both resized to ``self.size``."""
        stem = self.items[ix]

        image = read(f'{self.root}/images_prepped_{self.split}/{stem}.png', 1)
        image = cv2.resize(image, self.size)

        mask = read(f'{self.root}/annotations_prepped_{self.split}/{stem}.png')
        # Masks hold class ids. Nearest-neighbour resizing keeps every pixel a valid
        # id; the default bilinear interpolation would blend neighbouring labels into
        # bogus in-between classes along object boundaries.
        mask = cv2.resize(mask, self.size, interpolation=cv2.INTER_NEAREST)

        # TODO: data augmentation goes here — remember to apply the same transform
        # to both the mask and the image.

        return image, mask

    def choose(self):
        """Return a random ``(image, mask)`` pair from the dataset."""
        return self[randint(len(self))]

    def collate_fn(self, batch):
        """Stack a list of ``(image, mask)`` pairs into device tensors.

        Images are scaled to [0, 1], passed through ``get_transforms()`` and stacked
        to a float batch. Masks are stacked to a long tensor, as CrossEntropyLoss
        requires for targets. Both are moved to ``config.DEVICE``.
        """
        ims, masks = list(zip(*batch))

        # Build the transform pipeline once per batch instead of once per image.
        tfm = get_transforms()
        ims = torch.cat([tfm(im.copy() / 255.)[None] for im in ims]).float().to(config.DEVICE)

        ce_masks = torch.cat([torch.Tensor(mask[None]) for mask in masks]).long().to(config.DEVICE)

        return ims, ce_masks

In [None]:
def get_dataloaders(batch_size=16, val_batch_size=1):
  """Build the training and validation DataLoaders.

  The 'train' split is shuffled each epoch. The 'test' split, used for validation,
  is served in a fixed order so that per-epoch validation metrics and the
  prediction table are reproducible.

  Args:
    batch_size: training batch size.
    val_batch_size: validation batch size.

  Returns:
    ``(trn_dl, val_dl)``.
  """
  trn_ds = SegmentationData('train')
  val_ds = SegmentationData('test')

  trn_dl = DataLoader(trn_ds, batch_size=batch_size, shuffle=True, collate_fn=trn_ds.collate_fn)
  # Shuffling evaluation data gains nothing and makes its output order random.
  val_dl = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False, collate_fn=val_ds.collate_fn)

  return trn_dl, val_dl

In [None]:
# Build the train/validation loaders used by the training loop (run) and by save_table.
trn_dl, val_dl = get_dataloaders()

# U-Net Architecture

In [None]:
def conv(in_channels, out_channels):
    """3x3 convolution -> BatchNorm -> ReLU.

    Stride 1 with padding 1 keeps the spatial size unchanged; only the channel count
    goes from ``in_channels`` to ``out_channels``. Layers stay positional (indices
    0, 1, 2) so the state_dict keys are ``0.*`` / ``1.*``.
    """
    convolution = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
    normalisation = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(convolution, normalisation, activation)

In [None]:
def up_conv(in_channels, out_channels):
    """2x upsampling block: stride-2 transposed convolution followed by ReLU.

    A kernel of 2 with stride 2 doubles height and width while mapping
    ``in_channels`` to ``out_channels``.
    """
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(upsample, activation)

In [None]:
class UNet(nn.Module):
    """U-Net segmentation network with a torchvision VGG16-BN encoder.

    The VGG16-BN ``features`` stack is sliced into five encoder stages
    (``block1``..``block5``) and a ``bottleneck``. The decoder repeatedly upsamples
    with ``up_convN``, concatenates the matching encoder stage's output along the
    channel axis (skip connection) and fuses the result with ``convN``. A final 1x1
    convolution (``conv11``) yields ``out_channels`` logits per pixel.

    Args:
        pretrained: passed to ``vgg16_bn`` to load its pretrained weights.
        out_channels: number of output classes (logit channels).

    NOTE(review): the skip concatenations need the encoder/decoder spatial sizes to
    match; input height/width presumably must be multiples of 32 — TODO confirm.
    """

    def __init__(self, pretrained=True, out_channels=12):
        super().__init__()

        # Encoder: VGG16-BN feature layers cut into stages (kept as `self.encoder`
        # too, so the module/state_dict layout is unchanged).
        self.encoder = vgg16_bn(pretrained=pretrained).features
        vgg = self.encoder
        self.block1 = nn.Sequential(*vgg[:6])
        self.block2 = nn.Sequential(*vgg[6:13])
        self.block3 = nn.Sequential(*vgg[13:20])
        self.block4 = nn.Sequential(*vgg[20:27])
        self.block5 = nn.Sequential(*vgg[27:34])

        self.bottleneck = nn.Sequential(*vgg[34:])
        self.conv_bottleneck = conv(512, 1024)

        # Decoder: (upsample, fuse) pairs. Each fuse conv's input width is
        # upsampled channels + channels of the skip stage it is concatenated with.
        self.up_conv6, self.conv6 = up_conv(1024, 512), conv(512 + 512, 512)
        self.up_conv7, self.conv7 = up_conv(512, 256), conv(256 + 512, 256)
        self.up_conv8, self.conv8 = up_conv(256, 128), conv(128 + 256, 128)
        self.up_conv9, self.conv9 = up_conv(128, 64), conv(64 + 128, 64)
        self.up_conv10, self.conv10 = up_conv(64, 32), conv(32 + 64, 32)
        self.conv11 = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class logits for the input batch ``x``."""
        # Encoder pass, remembering every stage's output for the skip connections.
        skips = []
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = stage(x)
            skips.append(x)

        x = self.conv_bottleneck(self.bottleneck(x))

        # Decoder pass: the deepest skip (block5) pairs with the first upsample.
        decoder = (
            (self.up_conv6, self.conv6),
            (self.up_conv7, self.conv7),
            (self.up_conv8, self.conv8),
            (self.up_conv9, self.conv9),
            (self.up_conv10, self.conv10),
        )
        for (upsample, fuse), skip in zip(decoder, reversed(skips)):
            x = fuse(torch.cat([upsample(x), skip], dim=1))

        return self.conv11(x)

# Loss Function

In [None]:
ce = nn.CrossEntropyLoss()

def UnetLoss(preds, targets):
    """Segmentation loss and pixel accuracy.

    Args:
        preds: logits with the class axis at dim 1.
        targets: class-index tensor matching ``preds`` without the class axis.

    Returns:
        ``(cross_entropy_loss, pixel_accuracy)`` as 0-d tensors. The accuracy is the
        fraction of pixels whose arg-max class equals the target.
    """
    predicted_labels = preds.argmax(dim=1)
    pixel_accuracy = predicted_labels.eq(targets).float().mean()
    return ce(preds, targets), pixel_accuracy

# Engine

In [None]:
class engine():
    """Stateless per-batch training / validation steps.

    The methods are static, so they work both as ``engine.train_batch(...)`` and
    through an instance (``engine().train_batch(...)``).
    """

    @staticmethod
    def train_batch(model, data, optimizer, criterion):
        """Run one optimisation step on a single batch.

        Args:
            model: network being trained; switched to train mode here.
            data: ``(images, masks)`` tuple as produced by the DataLoader.
            optimizer: optimizer over ``model``'s parameters.
            criterion: callable ``(preds, targets) -> (loss, accuracy)`` returning
                0-d tensors.

        Returns:
            ``(loss, accuracy)`` as Python floats.
        """
        model.train()
        ims, ce_masks = data

        # Clear stale gradients before the forward/backward pass.
        optimizer.zero_grad()
        _masks = model(ims)
        loss, acc = criterion(_masks, ce_masks)
        loss.backward()
        optimizer.step()

        return loss.item(), acc.item()

    @staticmethod
    @torch.no_grad()
    def validate_batch(model, data, criterion):
        """Evaluate ``model`` on one batch without tracking gradients.

        Args:
            model: network to evaluate; switched to eval mode here.
            data: ``(images, masks)`` tuple as produced by the DataLoader.
            criterion: callable ``(preds, targets) -> (loss, accuracy)``.

        Returns:
            ``(loss, accuracy)`` as Python floats.
        """
        model.eval()
        ims, masks = data
        _masks = model(ims)

        loss, acc = criterion(_masks, masks)

        return loss.item(), acc.item()

In [None]:
def make_model():
  """Create the training triple.

  Returns:
    ``(model, criterion, optimizer)``: a default UNet moved to ``config.DEVICE``,
    the ``UnetLoss`` function, and Adam over the model's parameters with
    ``lr=config.LEARNING_RATE``.
  """
  net = UNet().to(config.DEVICE)
  adam = optim.Adam(net.parameters(), lr=config.LEARNING_RATE)
  return net, UnetLoss, adam

In [None]:
# Instantiate the network, loss and optimizer used by run() and save_table().
model, criterion, optimizer = make_model()

Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

# Train

In [None]:
def _epoch_means(loader, step):
  """Apply ``step`` to every batch in ``loader`` and average its metrics.

  ``step`` maps one batch to ``(loss, accuracy)`` floats. Each batch is weighted
  by its size (length of the batch's first element, the image tensor), so a short
  final batch does not skew the mean. Returns ``(nan, nan)`` for an empty loader.
  """
  total_loss = total_acc = 0.0
  n_seen = 0
  for data in tqdm(loader, total=len(loader)):
      loss, acc = step(data)
      n = len(data[0])
      total_loss += loss * n
      total_acc += acc * n
      n_seen += n
  if n_seen == 0:
      return float('nan'), float('nan')
  return total_loss / n_seen, total_acc / n_seen


def run():
  """Train the notebook-level ``model`` for ``config.N_EPOCHS`` epochs.

  Uses the global ``model``, ``optimizer``, ``criterion``, ``trn_dl`` and
  ``val_dl``. After every epoch, the epoch-mean train and validation loss and
  accuracy are logged to W&B. Previously only the last batch's values were logged,
  which for a validation batch size of 1 is a single image.
  """
  for epoch in range(config.N_EPOCHS):
      print("####################")
      print(f"       Epoch: {epoch}   ")
      print("####################")

      train_loss, train_acc = _epoch_means(
          trn_dl, lambda data: engine.train_batch(model, data, optimizer, criterion))
      val_loss, val_acc = _epoch_means(
          val_dl, lambda data: engine.validate_batch(model, data, criterion))

      wandb.log(
          {
              'epoch': epoch,
              'train_loss': train_loss,
              'train_acc': train_acc,
              'val_loss': val_loss,
              'val_acc': val_acc
          }
      )

      print()

In [None]:
# Launch training: config.N_EPOCHS epochs, with metrics logged to W&B after each epoch.
run()

100%|██████████| 5/5 [02:39<00:00, 31.85s/it]
100%|██████████| 16/16 [00:12<00:00,  1.33it/s]


100%|██████████| 5/5 [02:38<00:00, 31.62s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:34<00:00, 30.87s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:37<00:00, 31.42s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:34<00:00, 30.93s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:40<00:00, 32.08s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:38<00:00, 31.60s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:37<00:00, 31.55s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:36<00:00, 31.38s/it]
100%|██████████| 16/16 [00:13<00:00,  1.21it/s]


100%|██████████| 5/5 [02:40<00:00, 32.18s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:37<00:00, 31.43s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:39<00:00, 31.88s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:40<00:00, 32.10s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:39<00:00, 31.98s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:37<00:00, 31.44s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:38<00:00, 31.75s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:40<00:00, 32.14s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:38<00:00, 31.70s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:36<00:00, 31.34s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:44<00:00, 32.81s/it]
100%|██████████| 16/16 [00:12<00:00,  1.28it/s]


100%|██████████| 5/5 [02:35<00:00, 31.20s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:38<00:00, 31.71s/it]
100%|██████████| 16/16 [00:12<00:00,  1.32it/s]


100%|██████████| 5/5 [02:35<00:00, 31.07s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:41<00:00, 32.38s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:36<00:00, 31.22s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:38<00:00, 31.70s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:39<00:00, 31.87s/it]
100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


100%|██████████| 5/5 [02:36<00:00, 31.31s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:38<00:00, 31.68s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:35<00:00, 31.10s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:41<00:00, 32.33s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:35<00:00, 31.07s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:36<00:00, 31.37s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:35<00:00, 31.12s/it]
100%|██████████| 16/16 [00:14<00:00,  1.09it/s]


100%|██████████| 5/5 [02:38<00:00, 31.71s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:35<00:00, 31.00s/it]
100%|██████████| 16/16 [00:12<00:00,  1.32it/s]


100%|██████████| 5/5 [02:36<00:00, 31.40s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:41<00:00, 32.22s/it]
100%|██████████| 16/16 [00:13<00:00,  1.18it/s]


100%|██████████| 5/5 [02:35<00:00, 31.09s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:37<00:00, 31.50s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:35<00:00, 31.08s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:42<00:00, 32.40s/it]
100%|██████████| 16/16 [00:12<00:00,  1.32it/s]


100%|██████████| 5/5 [02:35<00:00, 31.02s/it]
100%|██████████| 16/16 [00:12<00:00,  1.32it/s]


100%|██████████| 5/5 [02:37<00:00, 31.43s/it]
100%|██████████| 16/16 [00:12<00:00,  1.33it/s]


100%|██████████| 5/5 [02:37<00:00, 31.55s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:37<00:00, 31.56s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:37<00:00, 31.41s/it]
100%|██████████| 16/16 [00:12<00:00,  1.33it/s]


100%|██████████| 5/5 [02:34<00:00, 30.97s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:42<00:00, 32.53s/it]
100%|██████████| 16/16 [00:12<00:00,  1.32it/s]


100%|██████████| 5/5 [02:34<00:00, 31.00s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:36<00:00, 31.36s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:37<00:00, 31.58s/it]
100%|██████████| 16/16 [00:14<00:00,  1.14it/s]


100%|██████████| 5/5 [02:36<00:00, 31.26s/it]
100%|██████████| 16/16 [00:12<00:00,  1.32it/s]


100%|██████████| 5/5 [02:36<00:00, 31.33s/it]
100%|██████████| 16/16 [00:12<00:00,  1.32it/s]


100%|██████████| 5/5 [02:34<00:00, 30.91s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:41<00:00, 32.34s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:35<00:00, 31.12s/it]
100%|██████████| 16/16 [00:12<00:00,  1.26it/s]


100%|██████████| 5/5 [02:36<00:00, 31.38s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:37<00:00, 31.52s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:35<00:00, 31.02s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:37<00:00, 31.56s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:35<00:00, 31.14s/it]
100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


100%|██████████| 5/5 [02:39<00:00, 31.81s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:38<00:00, 31.65s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:43<00:00, 32.64s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:35<00:00, 31.15s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:38<00:00, 31.71s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:37<00:00, 31.54s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:40<00:00, 32.02s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:38<00:00, 31.67s/it]
100%|██████████| 16/16 [00:12<00:00,  1.28it/s]


100%|██████████| 5/5 [02:36<00:00, 31.27s/it]
100%|██████████| 16/16 [00:12<00:00,  1.27it/s]


100%|██████████| 5/5 [02:37<00:00, 31.59s/it]
100%|██████████| 16/16 [00:12<00:00,  1.28it/s]


100%|██████████| 5/5 [02:39<00:00, 31.83s/it]
100%|██████████| 16/16 [00:12<00:00,  1.27it/s]


100%|██████████| 5/5 [02:37<00:00, 31.55s/it]
100%|██████████| 16/16 [00:12<00:00,  1.26it/s]


100%|██████████| 5/5 [02:39<00:00, 31.94s/it]
100%|██████████| 16/16 [00:12<00:00,  1.23it/s]


100%|██████████| 5/5 [02:40<00:00, 32.12s/it]
100%|██████████| 16/16 [00:12<00:00,  1.27it/s]


100%|██████████| 5/5 [02:36<00:00, 31.35s/it]
100%|██████████| 16/16 [00:14<00:00,  1.10it/s]


100%|██████████| 5/5 [02:35<00:00, 31.19s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:37<00:00, 31.44s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:34<00:00, 30.98s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:38<00:00, 31.62s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:37<00:00, 31.58s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:37<00:00, 31.58s/it]
100%|██████████| 16/16 [00:12<00:00,  1.28it/s]


100%|██████████| 5/5 [02:37<00:00, 31.57s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:37<00:00, 31.52s/it]
100%|██████████| 16/16 [00:12<00:00,  1.28it/s]


100%|██████████| 5/5 [02:37<00:00, 31.58s/it]
100%|██████████| 16/16 [00:12<00:00,  1.28it/s]


100%|██████████| 5/5 [02:37<00:00, 31.52s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:37<00:00, 31.60s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:35<00:00, 31.02s/it]
100%|██████████| 16/16 [00:13<00:00,  1.17it/s]


100%|██████████| 5/5 [02:35<00:00, 31.05s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:37<00:00, 31.45s/it]
100%|██████████| 16/16 [00:12<00:00,  1.30it/s]


100%|██████████| 5/5 [02:34<00:00, 30.86s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:37<00:00, 31.50s/it]
100%|██████████| 16/16 [00:12<00:00,  1.28it/s]


100%|██████████| 5/5 [02:37<00:00, 31.56s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:34<00:00, 30.98s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:37<00:00, 31.47s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:37<00:00, 31.56s/it]
100%|██████████| 16/16 [00:12<00:00,  1.29it/s]


100%|██████████| 5/5 [02:36<00:00, 31.21s/it]
100%|██████████| 16/16 [00:12<00:00,  1.28it/s]


100%|██████████| 5/5 [02:38<00:00, 31.71s/it]
100%|██████████| 16/16 [00:12<00:00,  1.28it/s]


100%|██████████| 5/5 [02:40<00:00, 32.05s/it]
100%|██████████| 16/16 [00:12<00:00,  1.28it/s]


In [None]:
# Persist the trained weights (the model's state_dict only, not the optimizer).
torch.save(model.state_dict(), 'checkpointv3.pth')

# Download the checkpoint (presumably google.colab.files — Colab only).
# NOTE(review): the filename below does not match the file saved above ('checkpointv3.pth').
#files.download('checkpoint.pth')

# Predictions

In [None]:
def _save_plot(array, path):
  """Render a 2-D (or HxWx3) tensor with matplotlib, without axes, and save it to ``path``."""
  fig, ax = plt.subplots(figsize=(10, 10))
  ax.axis("off")
  ax.imshow(array.detach().cpu().numpy())
  fig.savefig(path)
  plt.close(fig)  # free the figure; many are created in the loop


def _load_rgb(path):
  """Read an image file written by ``_save_plot`` back as an RGB ``wandb.Image``."""
  return wandb.Image(cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB))


def save_table(table_name):
  """Log a W&B table of validation image / ground-truth mask / predicted mask.

  Only the first sample of each validation batch is rendered.

  Args:
    table_name: key under which the table is passed to ``wandb.log``.
  """
  table = wandb.Table(columns=['Original Image', 'Original Mask', 'Predicted Mask'], allow_mixed_types = True)

  # Undo the Normalize() step so the image column shows real RGB instead of one
  # normalised channel. These values must match get_transforms().
  mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
  std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

  # Pure inference: use BatchNorm running stats and skip building autograd graphs.
  model.eval()
  with torch.no_grad():
    for im, mask in tqdm(val_dl, total=len(val_dl)):
      pred = model(im).argmax(dim=1)

      image = (im[0].cpu() * std + mean).clamp(0, 1).permute(1, 2, 0)
      panels = [
          (image, "original_image.jpg"),
          (mask[0], "original_mask.jpg"),
          (pred[0], "predicted_mask.jpg"),
      ]
      for array, path in panels:
        _save_plot(array, path)

      table.add_data(*[_load_rgb(path) for _, path in panels])

  wandb.log({table_name: table})

In [None]:
# Log the validation predictions to W&B under the table key "PredictionsV3".
save_table("PredictionsV3")

100%|██████████| 16/16 [00:23<00:00,  1.49s/it]


In [None]:
# Presumably keeps the (Colab) runtime alive after training so results can be
# inspected. Sleeping instead of `pass` avoids pinning a CPU core at 100% while idle.
# This cell never finishes: interrupt the kernel to stop it, and remove it for
# unattended "Restart & Run All" runs.
import time

while True:
    time.sleep(60)