## 1. Setup

In [1]:
import torch
import dsutils
import easydict
import numpy as np
from glob import glob
from tqdm import tqdm
from PIL import Image
import lightning as L
import albumentations as A
from transformers import AutoImageProcessor, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
cfg = easydict.EasyDict(
    model_name = '/hdd1/ysyoon/models/seggpt_vit-large/',
    batch_size = 4,
    num_training_steps = 100000
)

## 2. Data

In [3]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, dataset_names, augment, image_processor):
        self.root_dir = root_dir
        self.dataset_names = dataset_names
        self.augment = augment
        self.image_processor = image_processor

        self.num_datasets = len(self.dataset_names)
        self.image_paths = [sorted(glob(f'{root_dir}/{dname}/post/train/images/*')) for dname in dataset_names]
        self.num_images = [len(paths) for paths in self.image_paths]

    def __len__(self):
        return sum(self.num_images)

    def __getitem__(self, idx):
        ds_idx = np.random.randint(self.num_datasets)
        image_paths = self.image_paths[ds_idx]
        
        while True:
            image_path = image_paths[np.random.randint(self.num_images[ds_idx])]
            label_path = image_path.replace('/images/', '/labels/')
            
            image = np.array(Image.open(image_path))
            label = np.array(Image.open(label_path))

            unique_classes = np.unique(label)
            unique_classes = unique_classes[unique_classes>0]
            if len(unique_classes) > 0:
                break

        aug1 = self.augment(image=image, mask=label)
        image1, label1 = aug1['image'], aug1['mask']

        aug2 = self.augment(image=image, mask=label)
        image2, label2 = aug2['image'], aug2['mask']

        c = np.random.choice(unique_classes)
        label1 = np.where(label1==c, 1, 0)
        label2 = np.where(label2==c, 1, 0)
        
        palette = [[0,0,0], [255,255,255]]
        label1 = dsutils.segmentation.visualize_label(label1, palette)
        label2 = dsutils.segmentation.visualize_label(label2, palette)
        
        return image1, label1, image2, label2

    def collate_fn(self, batch):
        prompt_images, prompt_labels, input_images, input_labels = zip(*batch)
        batch = self.image_processor(prompt_images, prompt_labels, input_images, input_labels, return_tensors='pt')
        batch['mask'] = self.image_processor.generate_mask(len(prompt_images))
        return batch

In [4]:
augment = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ColorJitter(p=0.5),
    A.GaussNoise(p=0.5),
    A.Blur(p=0.5),
])

root_dir = '/hdd1/ysyoon/datasets/'
dataset_names = [
    'aihub-landcover-satellite-all',
    'aihub-satellite-object-cloud',
    'eorssd',
    'isaid',
    'open-earth-map',
    'orsi-sod'
]

In [5]:
image_processor = AutoImageProcessor.from_pretrained(cfg.model_name, trust_remote_code=True)
dataset = Dataset(root_dir, dataset_names, augment, image_processor)

In [6]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=dataset.collate_fn)

In [7]:
batch = next(iter(dataloader))
{k:v.shape for k,v in batch.items()}

{'images': torch.Size([4, 3, 896, 448]),
 'labels': torch.Size([4, 3, 896, 448]),
 'mask': torch.Size([4, 56, 28])}

## 3. Train

In [8]:
model = AutoModel.from_pretrained(cfg.model_name, trust_remote_code=True)

In [9]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

In [10]:
pbar = tqdm(range(1, cfg.num_training_steps+1))
dataiter = iter(dataloader)
for st in pbar:
    try:
        batch = next(dataiter)
    except StopIteration:
        dataiter = iter(dataloader)
        batch = next(dataiter)

    batch = dsutils.

SyntaxError: invalid syntax (2993102005.py, line 10)

In [11]:
batch.to()

{'images': tensor([[[[-0.5082, -0.5596, -0.4054,  ..., -0.3369, -0.2684, -0.2171],
          [-0.6109, -0.5767, -0.5082,  ..., -0.3541, -0.3712, -0.4054],
          [-0.6281, -0.6965, -0.6623,  ..., -0.3027, -0.4911, -0.5253],
          ...,
          [-1.2445, -1.1418, -0.9877,  ..., -0.5767, -0.4397, -0.3541],
          [-0.9877, -0.8849, -0.7993,  ..., -0.7308, -0.5424, -0.4226],
          [-0.7822, -0.6965, -0.6452,  ..., -0.7993, -0.5767, -0.4054]],

         [[-0.3200, -0.3901, -0.4251,  ..., -0.2150, -0.2675, -0.2850],
          [-0.3725, -0.4776, -0.5301,  ..., -0.2850, -0.3725, -0.4601],
          [-0.5651, -0.5651, -0.6877,  ..., -0.2675, -0.4076, -0.5826],
          ...,
          [-1.1429, -1.0203, -0.8627,  ..., -0.5126, -0.3901, -0.3025],
          [-0.8978, -0.7927, -0.6527,  ..., -0.6702, -0.4951, -0.3725],
          [-0.6877, -0.6001, -0.5476,  ..., -0.7052, -0.5126, -0.3550]],

         [[-0.4973, -0.4275, -0.4450,  ..., -0.3404, -0.3578, -0.3404],
          [-0.5321,

In [12]:
outputs = model(**batch)

TypeError: SegGPTModel.forward() got an unexpected keyword argument 'mask'

In [13]:
outputs.loss

NameError: name 'outputs' is not defined