# Import

In [4]:
from torch import nn
import torch
from models import DrFuse_RS_Model, DrFuse_RS_Trainer
import time
import os
from torch.utils.data import Dataset, DataLoader
from tqdm.autonotebook import tqdm
import skimage.io as io
import earthpy.spatial as es
import numpy as np
import random
import wandb


# Ensure deterministic behavior.
# The built-in hash() of a str is salted per interpreter process (see
# PYTHONHASHSEED), so seeds derived from hash("...") change on every kernel
# restart and cannot make runs repeatable. In addition,
# `hash(...) % 2**32 - 1` parses as `(hash(...) % 2**32) - 1`, which can be -1,
# a value np.random.seed() rejects. A fixed integer seed avoids both problems.
SEED = 42

torch.backends.cudnn.deterministic = True
random.seed(SEED)                  # Python's built-in RNG
np.random.seed(SEED)               # NumPy global RNG (used by the dataset's mask sampling)
torch.manual_seed(SEED)            # PyTorch default generator
torch.cuda.manual_seed_all(SEED)   # generators of all CUDA devices

# Data

In [5]:
# Unix timestamp (seconds) of when this cell ran; names the checkpoint sub-folder.
now = str(int(time.time()))

EXPERIMENT_CONFIG = dict(
    # Dataset root. Set the DRFUSE_DATA_FOLDER environment variable to point at
    # another location (e.g. /Users/nhikieu/Documents/PhD/multimodal on the mac)
    # instead of editing this cell; the Windows path below is the fallback.
    DATA_FOLDER = os.environ.get('DRFUSE_DATA_FOLDER',
                                 r'D:\CarsSegmentationTroubleShoot\data\multimodal'),
    num_epoch = 100,
    img_size = 512,
    num_classes = 6,
    # NOTE(review): the dataset cell stacks the image with one nDSM channel and the
    # unit-test cell feeds 4-channel tensors - confirm what the trainer does with 5.
    input_channels = 5,
    batch_size = 4,
    checkpoint_pth = os.path.join('checkpoints', now),
    # NOTE(review): presumably the number of samples accumulated before an
    # optimizer step - confirm against the trainer.
    update_optim_size = 32
    )

# Choose the best available backend instead of hard-coding 'mps', which only
# exists on Apple hardware.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

print(EXPERIMENT_CONFIG['checkpoint_pth'])

checkpoints\1720751784


In [6]:
def rgb_to_class(mask_image):
  '''
  Convert a colour-coded label image into a map of integer class indices.

  mask_image: array-like of shape (H, W, 3), or (H, W, 4) in which case the
              fourth (alpha) channel is ignored.
  returns:    uint8 numpy array of shape (H, W). Pixels whose colour is not
              listed in the colour table keep the initial value 0.
  raises:     ValueError if mask_image is not (H, W, 3) or (H, W, 4).
  '''
  class_map = {
    (255,255,255): 0, 
    (0,0,255): 1, 
    (0,255,255): 2, 
    (0,255,0): 3, 
    (255,255,0): 4, 
    (255,0,0): 5
  }

  mask_image = np.asarray(mask_image)
  if mask_image.ndim != 3 or mask_image.shape[2] not in (3, 4):
    raise ValueError(
      f'expected a mask of shape (H, W, 3) or (H, W, 4), got {mask_image.shape}')

  # Drop a possible alpha channel so only the RGB triplet is compared
  rgb_data = mask_image[..., :3]

  # Start from all zeros, then paint each class where all three channels match
  class_data = np.zeros(rgb_data.shape[:2], dtype=np.uint8)
  for rgb, class_label in class_map.items():
    class_data[np.all(rgb_data == np.array(rgb), axis=-1)] = class_label

  return class_data


class VaihingenDataset(Dataset):
  '''
  In-memory dataset that pairs each image in '<folder>/rgb' with its nDSM
  raster and its colour-coded ground truth.

  All samples are read and converted once in __init__; __getitem__ only
  indexes the cached tensors and draws a random 0/1 flag.
  '''

  def __init__(self, folder) -> None:
    '''
    folder: data path with three sub-folders
      rgb/  : <name>.png        image, divided by 255.0 on load
      ndsm/ : <name>_ndsm.tif   single-band raster, appended as the last channel
      gt/   : <name>_gt.png     3 channels colour map, converted with rgb_to_class
    '''
    super().__init__()

    self.imgs = []
    self.gts = []

    # Only .png files, in sorted order: os.listdir returns entries in arbitrary
    # order and may include stray files (Thumbs.db, .DS_Store, ...)
    filenames = self._select_image_files(os.listdir(os.path.join(folder, 'rgb')))

    for f in tqdm(filenames):
      stem = os.path.splitext(f)[0]

      img = io.imread(os.path.join(folder, 'rgb', f)) / 255.0

      ndsm = io.imread(os.path.join(folder, 'ndsm', stem + '_ndsm.tif'))
      ndsm = np.expand_dims(ndsm, axis=2)

      input_4C = np.dstack((img, ndsm))

      # `img / 255.0` is float64; cache as float32 because __getitem__ hands out
      # float32 anyway - this halves the memory held by the cached inputs.
      # permute: (H, W, C) -> (C, H, W)
      input_4C = torch.tensor(input_4C, dtype=torch.float32).permute((2, 0, 1))

      gt_path = os.path.join(folder, 'gt', stem + '_gt.png')
      gt = io.imread(gt_path)
      gt = rgb_to_class(gt)
      # add a leading axis: (H, W) -> (1, H, W)
      gt = torch.tensor(gt).unsqueeze(0)

      self.imgs.append(input_4C)
      self.gts.append(gt)

  @staticmethod
  def _select_image_files(filenames):
    '''Return the names ending in .png (case-insensitive), sorted.'''
    return sorted(f for f in filenames if f.lower().endswith('.png'))

  def __len__(self):
    '''Number of cached samples.'''
    return len(self.imgs)
  
  def __getitem__(self, index):
    '''
    returns (input_4C, gt, mask)
      input_4C : float32 tensor, channels first
      gt       : class-index tensor with a leading singleton axis
      mask     : scalar tensor, 0 or 1 with equal probability, re-drawn on every
                 access from NumPy's global RNG
                 NOTE(review): the meaning of this flag is defined by the
                 trainer/model, not here - confirm.
    '''
    input_4C = self.imgs[index].float()
    gt = self.gts[index]

    mask_index = np.random.choice(2, 1, p=[0.5, 0.5])
    mask = torch.tensor(mask_index[0])

    return input_4C, gt, mask

In [7]:
# Ground truth is read from the 'gt' sub-folder as colour-coded 3-channel masks;
# VaihingenDataset converts them to single-channel class maps while loading.
training_dataset = VaihingenDataset(os.path.join(EXPERIMENT_CONFIG['DATA_FOLDER'], 'train'))
validate_dataset = VaihingenDataset(os.path.join(EXPERIMENT_CONFIG['DATA_FOLDER'], 'val'))

# Shuffle the training set only. Validation keeps a fixed order so that runs are
# comparable (it was previously shuffled too, contradicting this note).
train_loader = DataLoader(dataset=training_dataset, batch_size=EXPERIMENT_CONFIG['batch_size'], shuffle=True, pin_memory=True)
validate_loader = DataLoader(dataset=validate_dataset, batch_size=EXPERIMENT_CONFIG['batch_size'], shuffle=False, pin_memory=True)

print(f'Train samples: {len(training_dataset)}')
print(f'Val samples: {len(validate_dataset)}')

100%|██████████| 1620/1620 [03:08<00:00,  8.61it/s]
100%|██████████| 120/120 [00:11<00:00, 10.30it/s]

Train samples: 1620
Val samples: 120





# Train

In [8]:
# Authenticate with Weights & Biases before the training cell calls wandb.init().
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnhikieu[0m ([33mqut_nhi_phd[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
# Checkpoint handed to the trainer (machine-specific absolute path).
# NOTE(review): presumably the best checkpoint of an earlier pre-training run -
# confirm, and consider moving this path into EXPERIMENT_CONFIG.
ckpt_path = r"D:\drfuse\checkpoints\1720700298\1720724395_loss1.8435"
# Open a W&B run for the duration of training, recording EXPERIMENT_CONFIG as its config.
with wandb.init(project="drfuse", entity="nhikieu", config=EXPERIMENT_CONFIG):
    # The trainer receives wandb's view of the config rather than the raw dict.
    config = wandb.config
    # NOTE(review): what pretrain=True does with ckpt is defined in models/drfuse_rs.py - confirm.
    model_trainer = DrFuse_RS_Trainer(config, pretrain=True, ckpt=ckpt_path)
    # Calling the trainer with both loaders starts the run.
    model_trainer(train_loader, validate_loader)

[34m[1mwandb[0m: Currently logged in as: [33mnhikieu[0m. Use [1m`wandb login --relogin`[0m to force relogin


  batch = (images.to(device), labels.to(device), torch.tensor(masks).to(device))
  0%|          | 0/150 [03:27<?, ?it/s]
Traceback (most recent call last):
  File "C:\Users\n10300759\AppData\Local\Temp\ipykernel_28720\866226422.py", line 6, in <module>
    model_trainer(train_loader, validate_loader)
  File "d:\drfuse\drfuse_venv\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "d:\drfuse\drfuse_venv\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "d:\drfuse\models\drfuse_rs.py", line 93, in forward
    train_loss = self.training_step(batch, log, self.pretrain)
  File "d:\drfuse\models\drfuse_rs.py", line 150, in training_step
    return train_loss.item()
KeyboardInterrupt


KeyboardInterrupt: 

# Unit Test

In [3]:
# Smoke test: push one random batch through a freshly constructed model and
# check the shape of every prediction head.
dummies = torch.rand(6, 4, 512, 512)
masks = torch.tensor([True, False, True, True, False, False], dtype=torch.float32)
model = DrFuse_RS_Model()

# Only shapes are inspected, so skip building the autograd graph.
with torch.no_grad():
    results = model(dummies, masks)
print(results['pred_rgb'].shape, results['pred_ndsm'].shape, 
      results['pred_shared'].shape, results['pred_multimodal'].shape,
      results['aux_preds'][0].shape)

# Every head is expected to produce (batch, 6, 512, 512), matching the shapes
# this cell printed previously; fail loudly if that ever changes.
expected_shape = torch.Size([dummies.shape[0], 6, 512, 512])
for key in ('pred_rgb', 'pred_ndsm', 'pred_shared', 'pred_multimodal'):
    assert results[key].shape == expected_shape, f'{key}: {results[key].shape}'
assert results['aux_preds'][0].shape == expected_shape, f"aux_preds[0]: {results['aux_preds'][0].shape}"

# Count only parameters that receive gradients
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total trainable parameters: {total_params}")



torch.Size([6, 6, 512, 512]) torch.Size([6, 6, 512, 512]) torch.Size([6, 6, 512, 512]) torch.Size([6, 6, 512, 512]) torch.Size([6, 6, 512, 512])
Total trainable parameters: 42120024


In [None]:
# Broadcasting check: a (6, 1) column of 0/1 flags multiplied with a (6, 512)
# feature matrix zeroes out every row whose flag is 0.
feat_dummy = torch.randn(6, 512)
pairs = torch.FloatTensor([True, False, True, False, True, True])[:, None]
test = feat_dummy * pairs
test

In [None]:
# Peek at the first validation batch only: print the shape of each of the three
# tensors the loader yields, then the batch of mask flags itself.
for i, (input_4C, gt, mask) in enumerate(validate_loader):
  print(*(t.shape for t in (input_4C, gt, mask)))
  print(mask)
  break