In [207]:
%load_ext autoreload
%autoreload 2

from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt 
from glob import glob
import os

import torch
import torch.nn as nn 
import torch.nn.functional as F
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode
from tensorflow.keras.utils import to_categorical

#---------------------------------------
import custom_dataset
import utils_rs

#----------------------------------------------
import swin
import upper_net_mmseg
import models


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [208]:
# Backbone: Swin model; its `out_channels` list sizes the decoder below.
model_swin = swin.swin()

# Decoder: mmseg-style UPerHead over backbone outputs 1..end
# (norm_cfg is left at UPerHead's default; a SyncBN config was previously disabled).
backbone_channels = model_swin.out_channels
model_upernet = upper_net_mmseg.UPerHead(
    in_channels=backbone_channels[1:],
    channels=backbone_channels[2],
    in_index=(0, 1, 2, 3),
    dropout_ratio=0.1,
)

# Full SAMRS model combining backbone (model1) and decoder (model2).
swin_samrs = models.SamRS(model1=model_swin, model2=model_upernet)

In [209]:

# iSAID label names: index 0 is background, followed by 15 object categories.
ISAID_CLASSES = (
    'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court',
    'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle',
    'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout',
    'Soccer_ball_field', 'plane', 'Harbor',
)

# Class index -> colour triple, listed in the same order as ISAID_CLASSES.
# NOTE(review): channel order (RGB vs BGR) is not established here — confirm at use site.
ISAID_PALETTE = dict(enumerate([
    (0, 0, 0), (0, 0, 63), (0, 63, 63), (0, 63, 0),
    (0, 63, 127), (0, 63, 191), (0, 63, 255), (0, 127, 63),
    (0, 127, 127), (0, 0, 127), (0, 0, 191), (0, 0, 255),
    (0, 191, 127), (0, 127, 191), (0, 127, 255), (0, 100, 155),
]))


In [210]:
# Image and mask folders live side by side under one data root
# (folder names suggest 1024-px tiles).
DATA_ROOT = "/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data"
img_path = os.path.join(DATA_ROOT, "04.1024_imgs")
mask_path = os.path.join(DATA_ROOT, "08.1024_masks_categorized_imgs_all")

# Every image needs a mask. NOTE(review): only the counts are compared; images are
# named like P0000_0_0.png while masks are 0.png, 1.png, ..., so correct pairing
# depends on how SegDataset orders the two listings — confirm.
n_imgs, n_masks = len(os.listdir(img_path)), len(os.listdir(mask_path))
assert n_imgs == n_masks

# Third argument presumably the sample size (tensors shown below are 512 x 512).
segdataset = custom_dataset.SegDataset(img_path, mask_path, 512)

['/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/04.1024_imgs/P0000_0_0.png', '/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/04.1024_imgs/P0000_0_1.png', '/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/04.1024_imgs/P0000_0_2.png', '/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/04.1024_imgs/P0000_1_0.png']
['/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/08.1024_masks_categorized_imgs_all/0.png', '/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/08.1024_masks_categorized_imgs_all/1.png', '/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/08.1024_masks_categorized_imgs_all/2.png', '/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/08.1024_masks_categorized_imgs_all/3.png']


In [211]:
# First sample: `a` is the model input, `b` the per-class mask (16 x 512 x 512, see below).
a, b = segdataset[0]

In [212]:
dummy_ = torch.randn(1, 3, 512, 512)  # random probe: batch 1, 3 channels, 512 x 512

In [213]:
dummy_.size()  # same torch.Size as .shape

torch.Size([1, 3, 512, 512])

In [214]:
# Push one real sample through the model (head still 18-channel at this point, per the
# output below) to compare against the 16-channel mask. This is inspection only, so
# skip autograd bookkeeping instead of keeping activations for a backward that never runs.
a_ = a.unsqueeze(0)  # add batch dimension
with torch.no_grad():
    out_ = swin_samrs(a_)

In [215]:
out_[0].size()  # channels produced for the first (only) sample

torch.Size([18, 512, 512])

In [216]:
b.size()  # mask shape, to compare with the model output above

torch.Size([16, 512, 512])

In [152]:
# Segmentation head as built (18 output channels, see the printed Conv2d).
swin_samrs.get_submodule("semseghead_1")

Sequential(
  (0): Dropout2d(p=0.1, inplace=False)
  (1): Conv2d(192, 18, kernel_size=(1, 1), stride=(1, 1))
)

In [77]:
# Width of backbone output index 2; the segmentation head's 1x1 conv takes this many
# input channels (192 here, matching Conv2d(192, ...) in the head printed above).
model_swin.out_channels[2]

192

In [153]:
# Fine-tuning: replace the segmentation head with a freshly initialized one whose
# output width matches the iSAID label set (len(ISAID_CLASSES) == 16). It keeps the
# original structure (Dropout2d + 1x1 Conv2d) and the same input width.
n_classes = len(ISAID_CLASSES)
swin_samrs.semseghead_1 = nn.Sequential(
    nn.Dropout2d(0.1),
    nn.Conv2d(model_swin.out_channels[2], n_classes, kernel_size=1),
)

In [154]:
# Replaced head (16 output channels, see the printed Conv2d).
swin_samrs.get_submodule("semseghead_1")

Sequential(
  (0): Dropout2d(p=0.1, inplace=False)
  (1): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1))
)

In [155]:
# Forward the random probe through the model with the replaced head; the first
# output should now have 16 channels, matching the mask.
out_ = swin_samrs(dummy_)
first_out_ = out_[0]
print(first_out_.size())

torch.Size([16, 512, 512])


In [156]:
b.size()  # mask shape, to compare with the new head's output

torch.Size([16, 512, 512])

In [123]:
pred = out_[0][None]  # re-add a leading batch dimension

In [126]:
pred.size()

torch.Size([1, 16, 512, 512])

In [127]:
target = b[None]  # add a leading batch dimension to the mask

In [128]:
target.size()

torch.Size([1, 16, 512, 512])

In [129]:
import loss_
# Dice loss from the project's loss_ module, used for a one-off check on a single
# prediction/target pair (the training loop further down uses CrossEntropyLoss).
loss_fn = loss_.DiceLoss()


In [130]:
# Dice loss between the 16-channel model output and the 16-channel mask, both with a
# leading batch dimension of 1.
qwe = loss_fn(pred, target)

In [131]:
# NOTE(review): the recorded value (~1.07) is above 1, which a Dice loss on
# probabilities normally should not exceed; check whether DiceLoss expects
# softmaxed inputs rather than (presumably) raw head outputs.
qwe

tensor(1.0717, grad_fn=<RsubBackward1>)

In [81]:
# Move the probe tensor to the GPU used below.
dummy_ = dummy_.to("cuda:0")

# Load the pretrained SAMRS checkpoint.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
w_path = "./swint_upernet_imp_sep_model.pth"
weights_ = torch.load(w_path, map_location=torch.device('cpu'))
pretrained_state = weights_['state_dict']

# In notebook order, semseghead_1 has already been replaced by a 16-channel head.
# The originally built head had 18 channels, presumably matching the checkpoint, so
# a strict load would raise a size-mismatch error. Skip exactly the tensors whose
# shapes differ from the current model, and stay strict about every other key.
model_state = swin_samrs.state_dict()
shape_mismatch = {
    k for k, v in pretrained_state.items()
    if k in model_state and model_state[k].shape != v.shape
}
load_result = swin_samrs.load_state_dict(
    {k: v for k, v in pretrained_state.items() if k not in shape_mismatch},
    strict=False,
)
unexpected_missing = set(load_result.missing_keys) - shape_mismatch
assert not unexpected_missing and not load_result.unexpected_keys, (
    f"checkpoint/model key mismatch: missing={sorted(unexpected_missing)}, "
    f"unexpected={sorted(load_result.unexpected_keys)}"
)
print(f"kept freshly initialized (shape mismatch): {sorted(shape_mismatch)}")

In [82]:
# Move the model onto the same GPU as the probe tensor.
swin_samrs = swin_samrs.to(torch.device("cuda:0"))

In [53]:
# GPU forward pass on the probe tensor. This is inspection only, so skip autograd
# bookkeeping instead of holding activations for a backward pass that never runs.
with torch.no_grad():
    out_ = swin_samrs(dummy_)

In [78]:
import loss_

In [79]:
# NOTE(review): second DiceLoss instance (an earlier cell already builds `loss_fn`);
# `diceloss` is not used by the training loop below, which uses CrossEntropyLoss.
diceloss = loss_.DiceLoss()

In [8]:
#-----------------------------------------

import wandb
import logging
from tqdm import tqdm
from lightning.fabric import Fabric
import lightning as L 

In [16]:
# # distribute gpus
# device_list = [0,1,2,3]
# fabric = L.Fabric(accelerator="cuda", devices=device_list, strategy='ddp')
# fabric.launch()

In [17]:
# log — per-iteration training losses are written to this file.
LOG_PATH = './1.log/model_v1.log'
# logging.FileHandler does not create missing directories.
os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)
# force=True: basicConfig is otherwise a silent no-op when the root logger already has
# handlers (e.g. from the kernel or an earlier run), and no log file would be written.
logging.basicConfig(
    filename=LOG_PATH,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)
logger = logging.getLogger(__name__)


In [18]:
# Authenticate with Weights & Biases (the recorded output shows an existing login was
# reused); the training loop below logs losses with wandb.log.
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mericpark[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [19]:
# opt 
# Fine-tuning loop (single GPU).
device = torch.device("cuda:0")
lr = 1e-5

# model — set training mode and move to the GPU *before* building the optimizer
# (as the torch.optim docs recommend), so the optimizer is built over GPU parameters.
swin_samrs = swin_samrs.train()
swin_samrs = swin_samrs.to(device)

# opt
optimizer = torch.optim.AdamW(swin_samrs.parameters(), lr=lr)

# dataset
batch_size = 4
segdataset = custom_dataset.SegDataset(img_path, mask_path, 512)
dataloader = torch.utils.data.DataLoader(segdataset, batch_size=batch_size, shuffle=True)

# Define loss. Each mask has one channel per class (16 x 512 x 512 per sample), the same
# shape as the model output. nn.CrossEntropyLoss therefore interprets it as per-class
# probabilities, which must be floating point (hence `.float()` below).
criterion = nn.CrossEntropyLoss()

# torch.save does not create missing directories.
ckpt_dir = "./2.ckpts"
os.makedirs(ckpt_dir, exist_ok=True)

# wandb.log raises unless a run is active, and no earlier cell starts one.
if wandb.run is None:
    wandb.init(config={"lr": lr, "batch_size": batch_size})

# run
epochs = 999
log_iter = 200  # print the running mean loss every `log_iter` batches
for epoch in range(epochs):
    running_loss = 0.0
    epoch_loss = 0.0
    n_batches = 0

    for i, data in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}")):
        img, mask = data
        img = img.to(device)
        mask = mask.to(device).float()

        optimizer.zero_grad()
        outputs = swin_samrs(img)
        loss = criterion(outputs, mask)
        loss.backward()
        optimizer.step()

        # stat — convert once to a Python float for all bookkeeping and logging.
        loss_value = loss.item()
        running_loss += loss_value
        epoch_loss += loss_value
        n_batches += 1

        # log — numeric value (not a string) so W&B can chart it.
        logger.info(f"[{epoch}, {i}] loss: {loss_value:.8f}")
        wandb.log({'loss': loss_value})

        # Report after each full window, so the mean is over exactly `log_iter` losses.
        if (i + 1) % log_iter == 0:
            print(f"epoch : {epoch} iter : {i} /  total_iter : {len(dataloader)} running_loss : {running_loss / log_iter}")
            running_loss = 0.0

    # -- epoch: mean loss over all batches (the window above may never fill on a small
    # dataset), then checkpoint.
    print(f"epoch : {epoch} mean_loss : {epoch_loss / max(n_batches, 1)}")
    save_path = os.path.join(ckpt_dir, f"swin_rs_{epoch + 1}.pt")
    torch.save(swin_samrs.state_dict(), save_path)


['/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/04.1024_imgs/P0000_0_0.png', '/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/04.1024_imgs/P0000_0_1.png', '/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/04.1024_imgs/P0000_0_2.png', '/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/04.1024_imgs/P0000_1_0.png']
['/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/08.1024_masks_categorized_imgs_all/0.png', '/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/08.1024_masks_categorized_imgs_all/1.png', '/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/08.1024_masks_categorized_imgs_all/2.png', '/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/08.1024_masks_categorized_imgs_all/3.png']


  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
Epoch 0:   0%|          | 0/21 [00:01<?, ?it/s]
