In [1]:
from multipage_classifier.dataset import UCSFDataset, UCSFDataModule
from multipage_classifier.preprocessor import ImageProcessor
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
NAME = "multipage_classifier"
DATASET="../dataset/ucsf-idl-resized/"
N_EPOCHS = 50
BATCH_SIZE = 2

IMAGE_SIZE  = (512, 1024)

In [3]:
from transformers import SwinModel, SwinConfig

config = SwinConfig(
    image_size=IMAGE_SIZE,
    depths=[2, 2, 14, 2],
    window_size=8,
    patch_size=8,
    embed_dim=96,
    num_heads=[4, 8, 16, 32],
    num_classes=0
)
swin_encoder = SwinModel(config)

In [4]:
img_w, img_h = swin_encoder.config.image_size

patch_size = swin_encoder.config.patch_size
patch_size_stage_4 = patch_size * 2 * 2 * 2
patches_stage_4 = (img_w /  patch_size_stage_4 )* (img_h / patch_size_stage_4)

emb_dim = swin_encoder.config.embed_dim
emb_dim_stage_4 = emb_dim * 2 * 2 * 2

encoder_output_size = int(patches_stage_4 * emb_dim_stage_4)
print(encoder_output_size)

98304


In [5]:
from multipage_classifier.encoder import EncoderForEmbedding

encoder = EncoderForEmbedding(swin_encoder, encoder_output_size)

In [6]:
from multipage_classifier.preprocessor import ImageProcessor
image_processor = ImageProcessor(img_size=IMAGE_SIZE)

data_module = UCSFDataModule(
    Path(DATASET),
    prepare_function=image_processor.prepare_input,
    split=[0.8, 0.2],
    batch_size=8,
    num_workers=1
)

In [7]:
import json
from torch.utils.data import DataLoader, Dataset, default_collate
from multipage_classifier.dataset import collate, val_collate
dataset_path = Path(DATASET)

labels = json.load(open(dataset_path / "labels.json"))

#random.shuffle(labels)

classes = []
for l in labels:
    if l["type"] not in classes:
        classes.append(l["type"])

train_dataset = UCSFDataset(
            dataset_path, labels, classes, image_processor.prepare_input
        )
dl = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
    collate_fn=collate,
)

collate([train_dataset[0], train_dataset[1]])["page_nr"]

tensor([0, 0, 1])

In [8]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl

# Configure checkpointing
checkpoint_callback = ModelCheckpoint(
    filename="best-checkpoint-{epoch:02d}-{val_loss:.4f}",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min",
    save_last=True,
)
checkpoint_callback.CHECKPOINT_NAME_LAST= "checkpoint-{epoch:02d}-{val_loss:.4f}"

logger = TensorBoardLogger("lightning_logs", name=NAME)

trainer = pl.Trainer(accelerator="gpu", devices=[0,1], logger=logger, max_epochs=N_EPOCHS, gpus=1)

trainer.fit(encoder, data_module)

  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type          | Params
------------------------------------------
0 | encoder | SwinModel     | 41.8 M
1 | heads   | ModuleDict    | 393 K 
2 | weights | ParameterDict | 4     
------------------------------------------
42.2 M    Trainable params
4         Non-trainable params
42.2 M    Total params
168.676   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|          | 0/59 [00:00<?, ?it/s] 

OutOfMemoryError: CUDA out of memory. Tried to allocate 1020.00 MiB (GPU 0; 23.69 GiB total capacity; 5.33 GiB already allocated; 843.88 MiB free; 6.26 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
import json
import os
from PIL import Image
from pathlib import Path
from tqdm import tqdm
labels = json.load(open(Path(DATASET) / "labels.json"))

for label in tqdm(labels):
    img_folder = label["image_folder"]
    for i, file in enumerate(os.listdir(Path(DATASET) / img_folder)):

        path =  Path(DATASET) / img_folder / file
        x = image_processor.prepare_input(Image.open(path)),