In [9]:
# train.py
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
from unetModel.unet_model import UNET
# from DiceLoss import DiceLoss
from unetModel.utils import (
    get_loaders_balloon,
    load_checkpoint,
    save_checkpoint,
    check_accuracy,
    save_predictions_as_imgs,
)
import os
import torch.nn as nn
import pytorch_lightning as pl
import torchvision

BASE_PATH = "./Balloons-1"

CHECKPOINT_PATH = "./model_cp/balloon_checkpoint.pth.tar"

# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# print("DEVICE", DEVICE)
BATCH_SIZE = 8
NUM_EPOCHS = 1
NUM_WORKERS = 2
IMAGE_HEIGHT = 160  # 1280 original
IMAGE_WIDTH = 240  # 1918 original
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = f"{BASE_PATH}/train/"
VAL_IMG_DIR = f"{BASE_PATH}/valid/"


In [11]:
train_transform = A.Compose(
    [
        A. Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A. Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)
val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)


In [5]:
model = UNET(in_channels=3, out_channels=1).to(DEVICE)


In [12]:
train_loaders, val_loaders = get_loaders_balloon(
    TRAIN_IMG_DIR,
    VAL_IMG_DIR,
    BATCH_SIZE,
    train_transform,
    val_transforms,
    NUM_WORKERS,
    PIN_MEMORY
)

In [13]:
trainer = pl.Trainer(
    accelerator="auto", 
    devices="auto",
    max_epochs=NUM_EPOCHS,
    # precision=16
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
trainer.fit(model, train_loaders, val_loaders)



  | Name        | Type               | Params
---------------------------------------------------
0 | downs       | ModuleList         | 4.7 M 
1 | ups         | ModuleList         | 12.2 M
2 | pool        | MaxPool2d          | 0     
3 | bottleneck  | DoubleConv         | 14.2 M
4 | final_conv  | Conv2d             | 65    
5 | loss_fn     | BCEWithLogitsLoss  | 0     
6 | accuracy    | MulticlassAccuracy | 0     
7 | my_accuracy | MyAccuracy         | 0     
---------------------------------------------------
31.0 M    Trainable params
0         Non-trainable params
31.0 M    Total params
124.151   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/eliaweiss/opt/anaconda3/envs/LayoutLMv3/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                           

/Users/eliaweiss/opt/anaconda3/envs/LayoutLMv3/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.
/Users/eliaweiss/opt/anaconda3/envs/LayoutLMv3/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (7) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 7/7 [01:24<00:00,  0.08it/s, v_num=29, loss=0.637, accuracy=0.750, dice_score=0.239] 

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 7/7 [01:25<00:00,  0.08it/s, v_num=29, loss=0.637, accuracy=0.750, dice_score=0.239]


In [22]:
loader = val_loaders
folder="./save_images"
device=DEVICE

In [17]:
for idx, (x, y) in enumerate(loader):
    x = x.to(device=device)
    with torch.no_grad():
        preds = torch.sigmoid(model(x))
        preds = (preds > 0.5).float()
    torchvision.utils.save_image(preds, os.path.join(folder,f"pred_{idx}.png"))
    torchvision.utils.save_image(y.unsqueeze(1), os.path.join(folder,f"correct_{idx}.png"))

FileNotFoundError: [Errno 2] No such file or directory: 'saved_images/pred_0.png'

In [23]:
torchvision.utils.save_image(preds, os.path.join(folder,f"pred_{idx}.png"))
