In [1]:
!git clone https://github.com/milesial/Pytorch-UNet

Cloning into 'Pytorch-UNet'...
remote: Enumerating objects: 618, done.[K
remote: Total 618 (delta 0), reused 0 (delta 0), pack-reused 618 (from 1)[K
Receiving objects: 100% (618/618), 47.42 MiB | 25.95 MiB/s, done.
Resolving deltas: 100% (334/334), done.


In [2]:
%cd Pytorch-UNet

/content/Pytorch-UNet


In [41]:
from unet import UNet
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNet(n_channels=8, n_classes=1)
model.to(device)

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [42]:
import tifffile
from torch.utils.data import Dataset
import numpy as np

class TiffDataset(Dataset):
    def __init__(self, images_dir, masks_dir, scale=1.0):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.scale = scale
        self.ids = [f.stem for f in self.images_dir.glob("*.tif")]

    def __getitem__(self, idx):
        name = self.ids[idx]
        img = tifffile.imread(str(self.images_dir / f"{name}.tif")).astype(np.float32) / 255.0  # [H, W, C]
        mask = tifffile.imread(str(self.masks_dir / f"{name}.tif")).astype(np.uint8)

        if img.ndim == 3:
            img = img.transpose(2, 0, 1)  # to [C, H, W]
        if mask.ndim == 3:
            mask = mask[:, :, 0]  # in case of RGB mask

        mask = tifffile.imread(str(self.masks_dir / f"{name}.tif")).astype(np.uint8)

        # Ensure binary 0/1
        if mask.max() > 1:
            mask = (mask > 127).astype(np.uint8)  # or use threshold of your choice

        mask = torch.from_numpy(mask).float()  # Use float to match loss expectations

        return {
            'image': torch.from_numpy(img),
            'mask': mask  # Already float
        }

    def __len__(self):
        return len(self.ids)


In [43]:
# Alle Gewichte einfrieren
for param in model.parameters():
    param.requires_grad = False

# Nur die Parameter der letzten Ebene(n) trainierbar machen
for param in model.outc.parameters():  # 'outc' ist meist der finale Conv-Layer im UNet
    param.requires_grad = True

In [49]:
!pip install importlib

Collecting importlib
  Downloading importlib-1.0.4.zip (7.1 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: importlib
  Building wheel for importlib (setup.py) ... [?25l[?25hdone
  Created wheel for importlib: filename=importlib-1.0.4-py3-none-any.whl size=5850 sha256=a4447fbd4da6c18a0910290879138f9b8eb655667a817fb66d63b5d0c04177e1
  Stored in directory: /root/.cache/pip/wheels/03/4a/6e/7c4a313549653a504574fa29f907139c752051ef05210df605
Successfully built importlib
Installing collected packages: importlib
Successfully installed importlib-1.0.4


In [56]:
import argparse
import logging
import os
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

import wandb
from evaluate import evaluate
from unet import UNet
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss

dir_checkpoint = Path('./checkpoints/')

import importlib
import utils.dice_score
importlib.reload(utils.dice_score)
import evaluate
importlib.reload(evaluate)
from evaluate import evaluate
from utils.dice_score import dice_loss


def train_model(
        model,
        device,
        epochs: int = 5,
        batch_size: int = 1,
        learning_rate: float = 1e-5,
        val_percent: float = 0.1,
        save_checkpoint: bool = True,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.999,
        gradient_clipping: float = 1.0,
):
    dataset = TiffDataset('/content/drive/MyDrive/Brick_Data_Train/Brick_Data_Train/Image', '/content/drive/MyDrive/Brick_Data_Train/Brick_Data_Train/Mask')


    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(
        dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
             val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
    )

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    global_step = 0

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                assert images.shape[1] == model.n_channels, \
                    f'Network has been defined with {model.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                    masks_pred = model(images)
                    if model.n_classes == 1:
                        loss = criterion(masks_pred.squeeze(1), true_masks.float())
                        loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                    else:
                        loss = criterion(masks_pred, true_masks)
                        loss += dice_loss(
                            F.softmax(masks_pred, dim=1).float(),
                            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                            multiclass=True
                        )

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                division_step = (n_train // (5 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:
                        histograms = {}
                        for tag, value in model.named_parameters():
                            if not value.requires_grad:
                                continue  # skip frozen layers
                            tag = tag.replace('/', '.')
                            if not (torch.isinf(value) | torch.isnan(value)).any():
                                histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                            if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
                                histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                        val_score = evaluate(model, val_loader, device, amp)
                        scheduler.step(val_score)

                        logging.info('Validation Dice score: {}'.format(val_score))
                        try:
                            experiment.log({
                                'learning rate': optimizer.param_groups[0]['lr'],
                                'validation Dice': val_score,
                                'images': wandb.Image(images[0].cpu()),
                                'masks': {
                                    'true': wandb.Image(true_masks[0].float().cpu()),
                                    'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                                },
                                'step': global_step,
                                'epoch': epoch,
                                **histograms
                            })
                        except:
                            pass

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')


def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()


train_model(model=model, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

  grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
Epoch 1/5:  15%|█▌        | 2/13 [00:04<00:25,  2.35s/img, loss (batch)=nan]
Validation round:   0%|          | 0/1 [00:00<?, ?batch/s][A
Validation round: 100%|██████████| 1/1 [00:02<00:00,  2.07s/batch][A
Epoch 1/5:  31%|███       | 4/13 [00:11<00:23,  2.66s/img, loss (batch)=nan]
Validation round:   0%|          | 0/1 [00:00<?, ?batch/s][A
Validation round: 100%|██████████| 1/1 [00:03<00:00,  3.06s/batch][A
Epoch 1/5:  46%|████▌     | 6/13 [00:19<00:22,  3.24s/img, loss (batch)=nan]
Validation round:   0%|          | 0/1 [00:00<?, ?batch/s][A
Validation round: 100%|██████████| 1/1 [00:02<00:00,  2.12s/batch][A
Epoch 1/5:  62%|██████▏   | 8/13 [00:25<00:15,  3.07s/img, loss (batch)=nan]
Validation round:   0%|          | 0/1 [00:00<?, ?batch/s][A
Validation round: 100%|██████████| 1/1 [00:03<00:00,  3.08s/batch][A
Epoch 1/5:  77%|███████▋  | 10/13 [00:33<00:10,  3.34s/img, loss (batch)=nan]
Validation round:   0%|        

KeyboardInterrupt: 