# Imports

In [1]:
from torch import nn
import torch
from models import DrFuse_RS_Model, DrFuse_RS_Trainer
import time
import os
from torch.utils.data import Dataset, DataLoader
from tqdm.autonotebook import tqdm
import skimage.io as io
import earthpy.spatial as es
import numpy as np
import random
import wandb


# Ensure deterministic behavior.
# A fixed integer seed is used on purpose: Python salts str hashes per process
# (PYTHONHASHSEED), so seeds derived from hash("some string") change on every
# kernel restart and runs are NOT repeatable. Also, `h % 2**32 - 1` parses as
# `(h % 2**32) - 1`, which can be -1 -- a value np.random.seed rejects.
SEED = 42

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may pick different kernels run to run
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

  from tqdm.autonotebook import tqdm


# Configuration and data

In [2]:
# Unix timestamp of this run; used as the checkpoint sub-folder name below.
now = str(int(time.time()))

EXPERIMENT_CONFIG = dict(
    # Windows data folder
    DATA_FOLDER = r'D:\CarsSegmentationTroubleShoot\data\multimodal',
    # macOS data folder
    # DATA_FOLDER = r'/Users/nhikieu/Documents/PhD/multimodal',
    num_epoch = 100,
    img_size = 512,
    # 6 classes, matching the 6 colours mapped by rgb_to_class
    num_classes = 6,
    # NOTE(review): VaihingenDataset stacks the image bands with ONE nDSM band;
    # confirm the image files carry 4 bands so the total matches this value.
    input_channels = 5,
    batch_size = 4,
    checkpoint_pth = os.path.join('checkpoints', now),
    # presumably the sample count per optimizer step (gradient accumulation) -- TODO confirm in the trainer
    update_optim_size = 32
    )

# Pick the best available accelerator rather than hard-coding 'mps': this notebook is
# run on both Windows (where MPS does not exist) and macOS.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

print(EXPERIMENT_CONFIG['checkpoint_pth'])

checkpoints\1720658579


In [3]:
def rgb_to_class(mask_image, class_map=None):
  '''Convert a colour-coded label image into a 2D array of class indices.

  mask_image: array of shape (H, W, C) with C >= 3. Only the first three
    channels are used, so RGBA PNGs (as returned by skimage.io.imread) work too.
  class_map: optional dict {(R, G, B): class_index}. Defaults to the 6-colour
    table used by this notebook.

  Returns a uint8 array of shape (H, W). Pixels whose colour is not in
  class_map are left as 0 (the same index as the white class).
  Raises ValueError if mask_image is not an (H, W, >=3) array.
  '''
  if class_map is None:
    class_map = {
      (255, 255, 255): 0,
      (0, 0, 255): 1,
      (0, 255, 255): 2,
      (0, 255, 0): 3,
      (255, 255, 0): 4,
      (255, 0, 0): 5,
    }

  mask_image = np.asarray(mask_image)
  if mask_image.ndim != 3 or mask_image.shape[2] < 3:
    raise ValueError(f'expected an (H, W, >=3) colour image, got shape {mask_image.shape}')

  # Flatten to one row per pixel: (H*W, 3). Dropping any alpha channel first keeps the
  # reshape aligned to pixel boundaries.
  rgb_data = mask_image[..., :3].reshape(-1, 3)

  # One label per pixel; unmatched colours stay 0.
  class_labels = np.zeros(rgb_data.shape[0], dtype=np.uint8)
  for rgb, class_label in class_map.items():
    mask = np.all(rgb_data == np.asarray(rgb), axis=1)
    class_labels[mask] = class_label

  # Back to the spatial layout of the input.
  return class_labels.reshape(mask_image.shape[:2])


class VaihingenDataset(Dataset):
  '''In-memory segmentation dataset: image + nDSM input, class-index target.

  Expected layout under ``folder``:
    rgb/<name>.png          input image, divided by 255 (assumes 8-bit images)
    ndsm/<name>_ndsm.tif    single-band height map, appended as the last input channel
    gt/<name>_gt.png        colour-coded label map, converted with rgb_to_class

  All samples are read eagerly in __init__ and kept in memory.
  '''

  def __init__(self, folder) -> None:
    '''
    folder: data path containing the 'rgb', 'ndsm' and 'gt' sub-folders
    '''
    super().__init__()

    self.imgs = []
    self.gts = []

    # Sorted so the sample order does not depend on os.listdir's filesystem-dependent
    # ordering; hidden files (e.g. macOS .DS_Store) are not images and are skipped.
    filenames = sorted(f for f in os.listdir(os.path.join(folder, 'rgb')) if not f.startswith('.'))

    for f in tqdm(filenames):
      stem = os.path.splitext(f)[0]

      img = io.imread(os.path.join(folder, 'rgb', f)) / 255.0

      ndsm = io.imread(os.path.join(folder, 'ndsm', stem + '_ndsm.tif'))
      ndsm = np.expand_dims(ndsm, axis=2)

      # (H, W, C_img + 1) -> channel-first tensor. Cast to float32 once here: the
      # float64 produced by "/ 255.0" would double resident memory, and __getitem__
      # returns float32 anyway, so the values handed out are unchanged.
      input_stack = np.dstack((img, ndsm)).astype(np.float32)
      input_4C = torch.from_numpy(input_stack).permute((2, 0, 1))

      gt_path = os.path.join(folder, 'gt', stem + '_gt.png')
      gt = rgb_to_class(io.imread(gt_path))
      gt = torch.tensor(gt).unsqueeze(0)  # (1, H, W), uint8 class indices

      self.imgs.append(input_4C)
      self.gts.append(gt)

  def __len__(self):
    return len(self.imgs)

  def __getitem__(self, index):
    '''Return (input float32 (C, H, W), gt uint8 (1, H, W), mask 0-dim int64 tensor).'''
    input_4C = self.imgs[index].float()
    gt = self.gts[index]

    # Random 0/1 flag with equal probability, drawn on every access, so repeated passes
    # over the same index (validation included) can see different values.
    # NOTE(review): its meaning is defined by the trainer -- presumably a modality
    # drop/keep switch; confirm.
    mask_index = np.random.choice(2, 1, p=[0.5, 0.5])
    mask = torch.tensor(mask_index[0])

    return input_4C, gt, mask

In [4]:
# Ground truth is read from each split's 'gt' folder as colour-coded RGB PNGs and is
# converted to single-channel class indices by rgb_to_class inside VaihingenDataset.
training_dataset = VaihingenDataset(os.path.join(EXPERIMENT_CONFIG['DATA_FOLDER'], 'train'))
validate_dataset = VaihingenDataset(os.path.join(EXPERIMENT_CONFIG['DATA_FOLDER'], 'val'))

# Shuffle the training set every epoch, but keep the validation order fixed so that
# successive evaluations iterate over the data identically.
train_loader = DataLoader(dataset=training_dataset, batch_size=EXPERIMENT_CONFIG['batch_size'], shuffle=True, pin_memory=True)
validate_loader = DataLoader(dataset=validate_dataset, batch_size=EXPERIMENT_CONFIG['batch_size'], shuffle=False, pin_memory=True)

print(f'Train samples: {len(training_dataset)}')
print(f'Val samples: {len(validate_dataset)}')

100%|██████████| 1620/1620 [02:40<00:00, 10.08it/s]
100%|██████████| 120/120 [00:06<00:00, 17.37it/s]

Train samples: 1620
Val samples: 120





# Train

In [5]:
# Authenticate with Weights & Biases. Cached credentials are reused (see the
# "Currently logged in as" output); use `wandb login --relogin` to switch accounts.
# NOTE(review): set the WANDB_NOTEBOOK_NAME environment variable to silence the
# "Failed to detect the name of this notebook" warning and enable code saving.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mnhikieu[0m ([33mqut_nhi_phd[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [6]:
ckpt_path = r"D:\drfuse\checkpoints\1720658579\1720683639_loss2.0603"
with wandb.init(project="drfuse", entity="nhikieu", config=EXPERIMENT_CONFIG):
    config = wandb.config
    model_trainer = DrFuse_RS_Trainer(config, pretrain=True, ckpt_path=ckpt_path)
    model_trainer(train_loader, validate_loader)

[34m[1mwandb[0m: Currently logged in as: [33mnhikieu[0m. Use [1m`wandb login --relogin`[0m to force relogin


  batch = (images.to(device), labels.to(device), torch.tensor(masks).to(device))
  batch = (images.to(device), labels.to(device), torch.tensor(masks).to(device))


Train loss after 00512 examples: 3.786
Val loss after 00512 examples: 4.040
Train loss after 01024 examples: 3.984
Val loss after 01024 examples: 3.988
Train loss after 01536 examples: 4.018
Val loss after 01536 examples: 3.965


  1%|          | 1/100 [03:52<6:23:40, 232.54s/it]

Train loss after 01620 examples: 3.619
Val loss after 01620 examples: 3.964
Train loss after 02048 examples: 3.691
Val loss after 02048 examples: 3.935
Train loss after 02560 examples: 3.723
Val loss after 02560 examples: 3.872
Train loss after 03072 examples: 3.761
Val loss after 03072 examples: 3.824
Train loss after 03240 examples: 3.562
Val loss after 03240 examples: 3.808


  2%|▏         | 2/100 [07:56<6:30:57, 239.37s/it]

Train loss after 03584 examples: 3.410
Val loss after 03584 examples: 3.793
Train loss after 04096 examples: 3.739
Val loss after 04096 examples: 3.778
Train loss after 04608 examples: 3.760
Val loss after 04608 examples: 3.767
Train loss after 04860 examples: 3.655
Val loss after 04860 examples: 3.766


  3%|▎         | 3/100 [12:09<6:37:00, 245.57s/it]

Train loss after 05120 examples: 3.824
Val loss after 05120 examples: 3.766
Train loss after 05632 examples: 3.697
Val loss after 05632 examples: 3.772
Train loss after 06144 examples: 3.432
Val loss after 06144 examples: 3.746
Train loss after 06480 examples: 3.577
Val loss after 06480 examples: 3.727


  4%|▍         | 4/100 [16:21<6:37:11, 248.25s/it]

Train loss after 06656 examples: 3.699
Val loss after 06656 examples: 3.711
Train loss after 07168 examples: 3.716
Val loss after 07168 examples: 3.691
Train loss after 07680 examples: 3.735
Val loss after 07680 examples: 3.669
Train loss after 08100 examples: 3.509
Val loss after 08100 examples: 3.655


  5%|▌         | 5/100 [20:34<6:35:42, 249.92s/it]

Train loss after 08192 examples: 3.722
Val loss after 08192 examples: 3.654
Train loss after 08704 examples: 3.483
Val loss after 08704 examples: 3.633
Train loss after 09216 examples: 3.236
Val loss after 09216 examples: 3.623
Train loss after 09720 examples: 3.353
Val loss after 09720 examples: 3.612


  6%|▌         | 6/100 [24:47<6:32:44, 250.68s/it]

Train loss after 09728 examples: 3.675
Val loss after 09728 examples: 3.611
Train loss after 10240 examples: 3.270
Val loss after 10240 examples: 3.605
Train loss after 10752 examples: 3.346
Val loss after 10752 examples: 3.599
Train loss after 11264 examples: 3.541
Val loss after 11264 examples: 3.596
Train loss after 11340 examples: 3.496
Val loss after 11340 examples: 3.596


  7%|▋         | 7/100 [29:16<6:37:50, 256.67s/it]

Train loss after 11776 examples: 3.528
Val loss after 11776 examples: 3.596
Train loss after 12288 examples: 3.475
Val loss after 12288 examples: 3.596
Train loss after 12800 examples: 3.399
Val loss after 12800 examples: 3.596
Train loss after 12960 examples: 3.582
Val loss after 12960 examples: 3.589


  8%|▊         | 8/100 [33:28<6:31:16, 255.18s/it]

Train loss after 13312 examples: 3.583
Val loss after 13312 examples: 3.568
Train loss after 13824 examples: 3.351
Val loss after 13824 examples: 3.546
Train loss after 14336 examples: 3.342
Val loss after 14336 examples: 3.527
Train loss after 14580 examples: 3.435
Val loss after 14580 examples: 3.515


  9%|▉         | 9/100 [37:40<6:25:57, 254.48s/it]

Train loss after 14848 examples: 3.263
Val loss after 14848 examples: 3.500
Train loss after 15360 examples: 3.546
Val loss after 15360 examples: 3.478
Train loss after 15872 examples: 3.432
Val loss after 15872 examples: 3.455
Train loss after 16200 examples: 3.472
Val loss after 16200 examples: 3.440


 10%|█         | 10/100 [41:53<6:21:02, 254.02s/it]

Train loss after 16384 examples: 3.350
Val loss after 16384 examples: 3.433
Train loss after 16896 examples: 3.262
Val loss after 16896 examples: 3.410
Train loss after 17408 examples: 3.184
Val loss after 17408 examples: 3.393


 11%|█         | 11/100 [46:07<6:16:39, 253.92s/it]

Train loss after 17820 examples: 3.416
Val loss after 17820 examples: 3.396
Train loss after 17920 examples: 3.515
Val loss after 17920 examples: 3.385
Train loss after 18432 examples: 3.235
Val loss after 18432 examples: 3.379
Train loss after 18944 examples: 3.495
Val loss after 18944 examples: 3.346
Train loss after 19440 examples: 3.272
Val loss after 19440 examples: 3.346


 12%|█▏        | 12/100 [50:20<6:11:48, 253.50s/it]

Train loss after 19456 examples: 3.412
Val loss after 19456 examples: 3.344
Train loss after 19968 examples: 3.326
Val loss after 19968 examples: 3.327
Train loss after 20480 examples: 3.124
Val loss after 20480 examples: 3.307
Train loss after 20992 examples: 3.176
Val loss after 20992 examples: 3.307


 13%|█▎        | 13/100 [54:49<6:14:31, 258.29s/it]

Train loss after 21060 examples: 2.921
Val loss after 21060 examples: 3.313
Train loss after 21504 examples: 3.251
Val loss after 21504 examples: 3.289
Train loss after 22016 examples: 2.975
Val loss after 22016 examples: 3.292
Train loss after 22528 examples: 3.116
Val loss after 22528 examples: 3.277
Train loss after 22680 examples: 3.228
Val loss after 22680 examples: 3.269


 14%|█▍        | 14/100 [59:01<6:07:35, 256.46s/it]

Train loss after 23040 examples: 3.014
Val loss after 23040 examples: 3.271
Train loss after 23552 examples: 3.165
Val loss after 23552 examples: 3.275
Train loss after 24064 examples: 3.094
Val loss after 24064 examples: 3.259
Train loss after 24300 examples: 3.143
Val loss after 24300 examples: 3.257


 15%|█▌        | 15/100 [1:03:13<6:01:13, 254.98s/it]

Train loss after 24576 examples: 3.232
Val loss after 24576 examples: 3.257
Train loss after 25088 examples: 3.090
Val loss after 25088 examples: 3.259
Train loss after 25600 examples: 3.307
Val loss after 25600 examples: 3.258


 16%|█▌        | 16/100 [1:07:26<5:56:21, 254.54s/it]

Train loss after 25920 examples: 3.192
Val loss after 25920 examples: 3.257
Train loss after 26112 examples: 3.225
Val loss after 26112 examples: 3.259
Train loss after 26624 examples: 2.998
Val loss after 26624 examples: 3.260
Train loss after 27136 examples: 2.993
Val loss after 27136 examples: 3.303


 17%|█▋        | 17/100 [1:11:38<5:50:55, 253.68s/it]

Train loss after 27540 examples: 3.133
Val loss after 27540 examples: 3.275
Train loss after 27648 examples: 3.368
Val loss after 27648 examples: 3.272
Train loss after 28160 examples: 3.077
Val loss after 28160 examples: 3.233
Train loss after 28672 examples: 3.302
Val loss after 28672 examples: 3.209
Train loss after 29160 examples: 3.145
Val loss after 29160 examples: 3.182


 18%|█▊        | 18/100 [1:15:51<5:46:20, 253.42s/it]

Train loss after 29184 examples: 3.297
Val loss after 29184 examples: 3.193
Train loss after 29696 examples: 3.303
Val loss after 29696 examples: 3.156
Train loss after 30208 examples: 3.221
Val loss after 30208 examples: 3.137
Train loss after 30720 examples: 3.022
Val loss after 30720 examples: 3.121
Train loss after 30780 examples: 2.937
Val loss after 30780 examples: 3.120


 19%|█▉        | 19/100 [1:20:20<5:48:42, 258.30s/it]

Train loss after 31232 examples: 2.975
Val loss after 31232 examples: 3.093
Train loss after 31744 examples: 2.930
Val loss after 31744 examples: 3.071
Train loss after 32256 examples: 2.994
Val loss after 32256 examples: 3.059
Train loss after 32400 examples: 3.006
Val loss after 32400 examples: 3.056


 20%|██        | 20/100 [1:24:34<5:42:31, 256.89s/it]

Train loss after 32768 examples: 3.152
Val loss after 32768 examples: 3.047
Train loss after 33280 examples: 3.121
Val loss after 33280 examples: 3.053
Train loss after 33792 examples: 2.853
Val loss after 33792 examples: 3.009


 21%|██        | 21/100 [1:28:48<5:36:53, 255.87s/it]

Train loss after 34020 examples: 2.896
Val loss after 34020 examples: 3.023
Train loss after 34304 examples: 2.883
Val loss after 34304 examples: 3.020
Train loss after 34816 examples: 2.998
Val loss after 34816 examples: 2.994
Train loss after 35328 examples: 2.801
Val loss after 35328 examples: 2.989
Train loss after 35640 examples: 2.708
Val loss after 35640 examples: 2.952


 22%|██▏       | 22/100 [1:33:00<5:31:15, 254.81s/it]

Train loss after 35840 examples: 2.893
Val loss after 35840 examples: 2.958
Train loss after 36352 examples: 2.548
Val loss after 36352 examples: 2.931
Train loss after 36864 examples: 2.810
Val loss after 36864 examples: 2.929


 23%|██▎       | 23/100 [1:37:10<5:25:20, 253.52s/it]

Train loss after 37260 examples: 2.651
Val loss after 37260 examples: 2.933
Train loss after 37376 examples: 2.631
Val loss after 37376 examples: 2.925
Train loss after 37888 examples: 2.740
Val loss after 37888 examples: 2.910
Train loss after 38400 examples: 2.790
Val loss after 38400 examples: 2.897
Train loss after 38880 examples: 2.707
Val loss after 38880 examples: 2.895


 24%|██▍       | 24/100 [1:41:23<5:20:51, 253.31s/it]

Train loss after 38912 examples: 2.671
Val loss after 38912 examples: 2.895
Train loss after 39424 examples: 2.913
Val loss after 39424 examples: 2.877
Train loss after 39936 examples: 2.490
Val loss after 39936 examples: 2.863
Train loss after 40448 examples: 2.687
Val loss after 40448 examples: 2.856
Train loss after 40500 examples: 2.790
Val loss after 40500 examples: 2.852


 25%|██▌       | 25/100 [1:45:51<5:22:11, 257.76s/it]

Train loss after 40960 examples: 2.519
Val loss after 40960 examples: 2.845
Train loss after 41472 examples: 2.674
Val loss after 41472 examples: 2.842
Train loss after 41984 examples: 2.698
Val loss after 41984 examples: 2.834
Train loss after 42120 examples: 2.727
Val loss after 42120 examples: 2.825


 26%|██▌       | 26/100 [1:50:03<5:15:35, 255.89s/it]

Train loss after 42496 examples: 2.682
Val loss after 42496 examples: 2.841
Train loss after 43008 examples: 2.659
Val loss after 43008 examples: 2.844
Train loss after 43520 examples: 2.626
Val loss after 43520 examples: 2.828
Train loss after 43740 examples: 2.744
Val loss after 43740 examples: 2.823


 27%|██▋       | 27/100 [1:54:15<5:09:55, 254.73s/it]

Train loss after 44032 examples: 2.945
Val loss after 44032 examples: 2.807
Train loss after 44544 examples: 2.770
Val loss after 44544 examples: 2.798
Train loss after 45056 examples: 2.541
Val loss after 45056 examples: 2.803
Train loss after 45360 examples: 2.538
Val loss after 45360 examples: 2.797


 28%|██▊       | 28/100 [1:58:27<5:04:35, 253.82s/it]

Train loss after 45568 examples: 2.607
Val loss after 45568 examples: 2.799
Train loss after 46080 examples: 2.614
Val loss after 46080 examples: 2.801
Train loss after 46592 examples: 2.728
Val loss after 46592 examples: 2.801
Train loss after 46980 examples: 2.344
Val loss after 46980 examples: 2.774


 29%|██▉       | 29/100 [2:02:38<4:59:35, 253.17s/it]

Train loss after 47104 examples: 2.741
Val loss after 47104 examples: 2.776
Train loss after 47616 examples: 2.563
Val loss after 47616 examples: 2.783
Train loss after 48128 examples: 2.764
Val loss after 48128 examples: 2.785


 30%|███       | 30/100 [2:06:50<4:54:41, 252.60s/it]

Train loss after 48600 examples: 2.522
Val loss after 48600 examples: 2.786
Train loss after 48640 examples: 3.001
Val loss after 48640 examples: 2.784
Train loss after 49152 examples: 2.755
Val loss after 49152 examples: 2.778
Train loss after 49664 examples: 2.289
Val loss after 49664 examples: 2.776
Train loss after 50176 examples: 2.750
Val loss after 50176 examples: 2.780


 31%|███       | 31/100 [2:11:17<4:55:38, 257.09s/it]

Train loss after 50220 examples: 2.697
Val loss after 50220 examples: 2.779
Train loss after 50688 examples: 2.805
Val loss after 50688 examples: 2.775
Train loss after 51200 examples: 2.550
Val loss after 51200 examples: 2.770
Train loss after 51712 examples: 2.581
Val loss after 51712 examples: 2.770


 32%|███▏      | 32/100 [2:15:30<4:49:51, 255.75s/it]

Train loss after 51840 examples: 2.560
Val loss after 51840 examples: 2.770
Train loss after 52224 examples: 2.480
Val loss after 52224 examples: 2.772
Train loss after 52736 examples: 2.748
Val loss after 52736 examples: 2.771
Train loss after 53248 examples: 2.994
Val loss after 53248 examples: 2.772


 33%|███▎      | 33/100 [2:19:41<4:44:12, 254.52s/it]

Train loss after 53460 examples: 2.655
Val loss after 53460 examples: 2.773
Train loss after 53760 examples: 2.335
Val loss after 53760 examples: 2.774
Train loss after 54272 examples: 2.296
Val loss after 54272 examples: 2.774
Train loss after 54784 examples: 2.645
Val loss after 54784 examples: 2.774


 34%|███▍      | 34/100 [2:23:52<4:38:42, 253.37s/it]

Train loss after 55080 examples: 2.874
Val loss after 55080 examples: 2.773
Train loss after 55296 examples: 2.722
Val loss after 55296 examples: 2.772
Train loss after 55808 examples: 2.754
Val loss after 55808 examples: 2.940
Train loss after 56320 examples: 2.921
Val loss after 56320 examples: 2.857


 35%|███▌      | 35/100 [2:28:04<4:33:58, 252.91s/it]

Train loss after 56700 examples: 2.533
Val loss after 56700 examples: 2.846
Train loss after 56832 examples: 2.874
Val loss after 56832 examples: 2.829
Train loss after 57344 examples: 2.716
Val loss after 57344 examples: 2.802
Train loss after 57856 examples: 3.066
Val loss after 57856 examples: 2.800


 36%|███▌      | 36/100 [2:32:14<4:29:00, 252.19s/it]

Train loss after 58320 examples: 2.669
Val loss after 58320 examples: 2.784
Train loss after 58368 examples: 2.487
Val loss after 58368 examples: 2.775
Train loss after 58880 examples: 2.620
Val loss after 58880 examples: 2.754
Train loss after 59392 examples: 2.815
Val loss after 59392 examples: 2.739
Train loss after 59904 examples: 2.467
Val loss after 59904 examples: 2.727
Train loss after 59940 examples: 2.623
Val loss after 59940 examples: 2.726


 37%|███▋      | 37/100 [2:36:43<4:29:57, 257.10s/it]

Train loss after 60416 examples: 2.613
Val loss after 60416 examples: 2.743
Train loss after 60928 examples: 2.422
Val loss after 60928 examples: 2.729
Train loss after 61440 examples: 2.436
Val loss after 61440 examples: 2.738


 38%|███▊      | 38/100 [2:40:54<4:23:41, 255.19s/it]

Train loss after 61560 examples: 2.808
Val loss after 61560 examples: 2.812
Train loss after 61952 examples: 2.591
Val loss after 61952 examples: 2.775
Train loss after 62464 examples: 2.475
Val loss after 62464 examples: 2.745
Train loss after 62976 examples: 2.610
Val loss after 62976 examples: 2.732


 39%|███▉      | 39/100 [2:45:05<4:18:12, 253.97s/it]

Train loss after 63180 examples: 2.618
Val loss after 63180 examples: 2.730
Train loss after 63488 examples: 2.562
Val loss after 63488 examples: 2.729
Train loss after 64000 examples: 2.787
Val loss after 64000 examples: 2.719
Train loss after 64512 examples: 2.621
Val loss after 64512 examples: 2.716


 40%|████      | 40/100 [2:49:16<4:13:10, 253.18s/it]

Train loss after 64800 examples: 2.589
Val loss after 64800 examples: 2.719
Train loss after 65024 examples: 2.577
Val loss after 65024 examples: 2.689
Train loss after 65536 examples: 2.386
Val loss after 65536 examples: 2.688
Train loss after 66048 examples: 2.210
Val loss after 66048 examples: 2.676
Train loss after 66420 examples: 2.764
Val loss after 66420 examples: 2.666


 41%|████      | 41/100 [2:53:28<4:08:40, 252.89s/it]

Train loss after 66560 examples: 2.554
Val loss after 66560 examples: 2.720
Train loss after 67072 examples: 2.303
Val loss after 67072 examples: 2.704
Train loss after 67584 examples: 2.556
Val loss after 67584 examples: 2.688
Train loss after 68040 examples: 2.272
Val loss after 68040 examples: 2.639


 42%|████▏     | 42/100 [2:57:40<4:04:01, 252.44s/it]

Train loss after 68096 examples: 2.701
Val loss after 68096 examples: 2.650
Train loss after 68608 examples: 2.614
Val loss after 68608 examples: 2.651
Train loss after 69120 examples: 2.463
Val loss after 69120 examples: 2.664
Train loss after 69632 examples: 2.581
Val loss after 69632 examples: 2.636


 43%|████▎     | 43/100 [3:02:07<4:03:55, 256.76s/it]

Train loss after 69660 examples: 2.332
Val loss after 69660 examples: 2.658
Train loss after 70144 examples: 2.585
Val loss after 70144 examples: 2.614
Train loss after 70656 examples: 2.533
Val loss after 70656 examples: 2.616
Train loss after 71168 examples: 2.570
Val loss after 71168 examples: 2.606


 44%|████▍     | 44/100 [3:06:18<3:58:08, 255.15s/it]

Train loss after 71280 examples: 2.528
Val loss after 71280 examples: 2.629
Train loss after 71680 examples: 2.375
Val loss after 71680 examples: 2.642
Train loss after 72192 examples: 2.461
Val loss after 72192 examples: 2.638
Train loss after 72704 examples: 2.234
Val loss after 72704 examples: 2.611


 45%|████▌     | 45/100 [3:10:30<3:52:53, 254.07s/it]

Train loss after 72900 examples: 2.526
Val loss after 72900 examples: 2.610
Train loss after 73216 examples: 2.281
Val loss after 73216 examples: 2.589
Train loss after 73728 examples: 2.468
Val loss after 73728 examples: 2.593
Train loss after 74240 examples: 2.502
Val loss after 74240 examples: 2.563


 46%|████▌     | 46/100 [3:14:42<3:48:09, 253.50s/it]

Train loss after 74520 examples: 2.453
Val loss after 74520 examples: 2.631
Train loss after 74752 examples: 2.297
Val loss after 74752 examples: 2.552
Train loss after 75264 examples: 2.485
Val loss after 75264 examples: 2.556
Train loss after 75776 examples: 2.219
Val loss after 75776 examples: 2.555
Train loss after 76140 examples: 2.273
Val loss after 76140 examples: 2.548


 47%|████▋     | 47/100 [3:18:54<3:43:35, 253.13s/it]

Train loss after 76288 examples: 2.240
Val loss after 76288 examples: 2.604
Train loss after 76800 examples: 2.573
Val loss after 76800 examples: 2.595
Train loss after 77312 examples: 2.295
Val loss after 77312 examples: 2.546
Train loss after 77760 examples: 2.301
Val loss after 77760 examples: 2.528


 48%|████▊     | 48/100 [3:23:06<3:39:10, 252.90s/it]

Train loss after 77824 examples: 2.321
Val loss after 77824 examples: 2.521
Train loss after 78336 examples: 2.571
Val loss after 78336 examples: 2.547
Train loss after 78848 examples: 2.381
Val loss after 78848 examples: 2.555
Train loss after 79360 examples: 2.264
Val loss after 79360 examples: 2.526


 49%|████▉     | 49/100 [3:27:36<3:39:19, 258.03s/it]

Train loss after 79380 examples: 2.427
Val loss after 79380 examples: 2.549
Train loss after 79872 examples: 2.146
Val loss after 79872 examples: 2.538
Train loss after 80384 examples: 2.371
Val loss after 80384 examples: 2.533
Train loss after 80896 examples: 2.783
Val loss after 80896 examples: 2.529
Train loss after 81000 examples: 2.147
Val loss after 81000 examples: 2.517


 50%|█████     | 50/100 [3:31:49<3:33:33, 256.27s/it]

Train loss after 81408 examples: 2.265
Val loss after 81408 examples: 2.544
Train loss after 81920 examples: 2.416
Val loss after 81920 examples: 2.507
Train loss after 82432 examples: 2.284
Val loss after 82432 examples: 2.494


 51%|█████     | 51/100 [3:36:01<3:28:23, 255.17s/it]

Train loss after 82620 examples: 2.485
Val loss after 82620 examples: 2.507
Train loss after 82944 examples: 2.681
Val loss after 82944 examples: 2.500
Train loss after 83456 examples: 2.365
Val loss after 83456 examples: 2.488
Train loss after 83968 examples: 2.555
Val loss after 83968 examples: 2.487
Train loss after 84240 examples: 2.224
Val loss after 84240 examples: 2.461


 52%|█████▏    | 52/100 [3:40:14<3:23:36, 254.51s/it]

Train loss after 84480 examples: 2.002
Val loss after 84480 examples: 2.489
Train loss after 84992 examples: 2.370
Val loss after 84992 examples: 2.465
Train loss after 85504 examples: 2.148
Val loss after 85504 examples: 2.450


 53%|█████▎    | 53/100 [3:44:28<3:19:07, 254.20s/it]

Train loss after 85860 examples: 2.388
Val loss after 85860 examples: 2.460
Train loss after 86016 examples: 2.270
Val loss after 86016 examples: 2.465
Train loss after 86528 examples: 2.119
Val loss after 86528 examples: 2.454
Train loss after 87040 examples: 2.583
Val loss after 87040 examples: 2.456


 54%|█████▍    | 54/100 [3:48:40<3:14:27, 253.63s/it]

Train loss after 87480 examples: 2.383
Val loss after 87480 examples: 2.531
Train loss after 87552 examples: 2.379
Val loss after 87552 examples: 2.487
Train loss after 88064 examples: 2.379
Val loss after 88064 examples: 2.459
Train loss after 88576 examples: 2.311
Val loss after 88576 examples: 2.441
Train loss after 89088 examples: 2.423
Val loss after 89088 examples: 2.459


 55%|█████▌    | 55/100 [3:53:07<3:13:09, 257.55s/it]

Train loss after 89100 examples: 2.244
Val loss after 89100 examples: 2.521
Train loss after 89600 examples: 2.448
Val loss after 89600 examples: 2.434
Train loss after 90112 examples: 2.445
Val loss after 90112 examples: 2.457
Train loss after 90624 examples: 2.076
Val loss after 90624 examples: 2.426
Train loss after 90720 examples: 1.925
Val loss after 90720 examples: 2.412


 56%|█████▌    | 56/100 [3:57:18<3:07:36, 255.82s/it]

Train loss after 91136 examples: 2.019
Val loss after 91136 examples: 2.445
Train loss after 91648 examples: 1.892
Val loss after 91648 examples: 2.418
Train loss after 92160 examples: 2.342
Val loss after 92160 examples: 2.412


 57%|█████▋    | 57/100 [4:01:30<3:02:31, 254.68s/it]

Train loss after 92340 examples: 2.306
Val loss after 92340 examples: 2.435
Train loss after 92672 examples: 2.359
Val loss after 92672 examples: 2.414
Train loss after 93184 examples: 2.114
Val loss after 93184 examples: 2.421
Train loss after 93696 examples: 2.388
Val loss after 93696 examples: 2.402


 58%|█████▊    | 58/100 [4:05:42<2:57:38, 253.78s/it]

Train loss after 93960 examples: 2.416
Val loss after 93960 examples: 2.424
Train loss after 94208 examples: 2.282
Val loss after 94208 examples: 2.415
Train loss after 94720 examples: 2.177
Val loss after 94720 examples: 2.414
Train loss after 95232 examples: 2.303
Val loss after 95232 examples: 2.390


 59%|█████▉    | 59/100 [4:09:53<2:52:51, 252.97s/it]

Train loss after 95580 examples: 2.217
Val loss after 95580 examples: 2.411
Train loss after 95744 examples: 2.228
Val loss after 95744 examples: 2.392
Train loss after 96256 examples: 1.993
Val loss after 96256 examples: 2.416
Train loss after 96768 examples: 2.266
Val loss after 96768 examples: 2.399


 60%|██████    | 60/100 [4:14:07<2:48:47, 253.18s/it]

Train loss after 97200 examples: 2.246
Val loss after 97200 examples: 2.397
Train loss after 97280 examples: 2.131
Val loss after 97280 examples: 2.397
Train loss after 97792 examples: 2.040
Val loss after 97792 examples: 2.386
Train loss after 98304 examples: 2.356
Val loss after 98304 examples: 2.391
Train loss after 98816 examples: 2.509
Val loss after 98816 examples: 2.392


 61%|██████    | 61/100 [4:18:36<2:47:44, 258.08s/it]

Train loss after 98820 examples: 2.309
Val loss after 98820 examples: 2.395
Train loss after 99328 examples: 2.013
Val loss after 99328 examples: 2.377
Train loss after 99840 examples: 2.253
Val loss after 99840 examples: 2.397
Train loss after 100352 examples: 2.299
Val loss after 100352 examples: 2.386


 62%|██████▏   | 62/100 [4:22:50<2:42:31, 256.62s/it]

Train loss after 100440 examples: 2.395
Val loss after 100440 examples: 2.385
Train loss after 100864 examples: 2.166
Val loss after 100864 examples: 2.379
Train loss after 101376 examples: 2.098
Val loss after 101376 examples: 2.391
Train loss after 101888 examples: 2.200
Val loss after 101888 examples: 2.378


 63%|██████▎   | 63/100 [4:27:02<2:37:23, 255.22s/it]

Train loss after 102060 examples: 2.186
Val loss after 102060 examples: 2.379
Train loss after 102400 examples: 2.287
Val loss after 102400 examples: 2.385
Train loss after 102912 examples: 2.166
Val loss after 102912 examples: 2.384
Train loss after 103424 examples: 2.319
Val loss after 103424 examples: 2.379
Train loss after 103680 examples: 1.821
Val loss after 103680 examples: 2.375


 64%|██████▍   | 64/100 [4:31:14<2:32:36, 254.35s/it]

Train loss after 103936 examples: 2.322
Val loss after 103936 examples: 2.374
Train loss after 104448 examples: 2.174
Val loss after 104448 examples: 2.376
Train loss after 104960 examples: 2.186
Val loss after 104960 examples: 2.384


 65%|██████▌   | 65/100 [4:35:25<2:27:53, 253.54s/it]

Train loss after 105300 examples: 2.313
Val loss after 105300 examples: 2.383
Train loss after 105472 examples: 2.056
Val loss after 105472 examples: 2.381
Train loss after 105984 examples: 2.194
Val loss after 105984 examples: 2.374
Train loss after 106496 examples: 2.240
Val loss after 106496 examples: 2.373


 66%|██████▌   | 66/100 [4:39:38<2:23:29, 253.23s/it]

Train loss after 106920 examples: 2.331
Val loss after 106920 examples: 2.378
Train loss after 107008 examples: 2.269
Val loss after 107008 examples: 2.379
Train loss after 107520 examples: 2.047
Val loss after 107520 examples: 2.377
Train loss after 108032 examples: 2.263
Val loss after 108032 examples: 2.375


 67%|██████▋   | 67/100 [4:43:48<2:18:44, 252.26s/it]

Train loss after 108540 examples: 2.565
Val loss after 108540 examples: 2.374
Train loss after 108544 examples: 2.206
Val loss after 108544 examples: 2.375
Train loss after 109056 examples: 2.188
Val loss after 109056 examples: 2.376
Train loss after 109568 examples: 1.994
Val loss after 109568 examples: 2.374
Train loss after 110080 examples: 2.281
Val loss after 110080 examples: 2.374


 68%|██████▊   | 68/100 [4:48:16<2:17:06, 257.07s/it]

Train loss after 110160 examples: 2.352
Val loss after 110160 examples: 2.374
Train loss after 110592 examples: 2.029
Val loss after 110592 examples: 2.377
Train loss after 111104 examples: 2.243
Val loss after 111104 examples: 2.375
Train loss after 111616 examples: 1.917
Val loss after 111616 examples: 2.375


 69%|██████▉   | 69/100 [4:52:28<2:11:59, 255.46s/it]

Train loss after 111780 examples: 2.093
Val loss after 111780 examples: 2.376
Train loss after 112128 examples: 2.133
Val loss after 112128 examples: 2.374
Train loss after 112640 examples: 2.251
Val loss after 112640 examples: 3.156
Train loss after 113152 examples: 2.322
Val loss after 113152 examples: 2.645


 70%|███████   | 70/100 [4:56:42<2:07:29, 254.99s/it]

Train loss after 113400 examples: 2.319
Val loss after 113400 examples: 2.562
Train loss after 113664 examples: 2.666
Val loss after 113664 examples: 2.614
Train loss after 114176 examples: 2.433
Val loss after 114176 examples: 2.573
Train loss after 114688 examples: 2.674
Val loss after 114688 examples: 2.507


 71%|███████   | 71/100 [5:00:54<2:02:47, 254.06s/it]

Train loss after 115020 examples: 2.372
Val loss after 115020 examples: 2.555
Train loss after 115200 examples: 2.544
Val loss after 115200 examples: 2.520
Train loss after 115712 examples: 2.603
Val loss after 115712 examples: 2.518
Train loss after 116224 examples: 2.152
Val loss after 116224 examples: 2.501


 72%|███████▏  | 72/100 [5:05:05<1:58:07, 253.13s/it]

Train loss after 116640 examples: 2.604
Val loss after 116640 examples: 2.488
Train loss after 116736 examples: 2.284
Val loss after 116736 examples: 2.459
Train loss after 117248 examples: 2.456
Val loss after 117248 examples: 2.469
Train loss after 117760 examples: 2.462
Val loss after 117760 examples: 2.457


 73%|███████▎  | 73/100 [5:09:16<1:53:39, 252.58s/it]

Train loss after 118260 examples: 2.227
Val loss after 118260 examples: 2.437
Train loss after 118272 examples: 2.253
Val loss after 118272 examples: 2.436
Train loss after 118784 examples: 2.559
Val loss after 118784 examples: 2.450
Train loss after 119296 examples: 2.448
Val loss after 119296 examples: 2.428
Train loss after 119808 examples: 2.407
Val loss after 119808 examples: 2.429


 74%|███████▍  | 74/100 [5:13:44<1:51:24, 257.10s/it]

Train loss after 119880 examples: 2.304
Val loss after 119880 examples: 2.424
Train loss after 120320 examples: 2.267
Val loss after 120320 examples: 2.419
Train loss after 120832 examples: 2.175
Val loss after 120832 examples: 2.394
Train loss after 121344 examples: 2.418
Val loss after 121344 examples: 2.390


 75%|███████▌  | 75/100 [5:17:56<1:46:32, 255.72s/it]

Train loss after 121500 examples: 2.355
Val loss after 121500 examples: 2.447
Train loss after 121856 examples: 2.376
Val loss after 121856 examples: 2.412
Train loss after 122368 examples: 2.134
Val loss after 122368 examples: 2.446
Train loss after 122880 examples: 2.440
Val loss after 122880 examples: 2.423


 76%|███████▌  | 76/100 [5:22:07<1:41:45, 254.40s/it]

Train loss after 123120 examples: 2.574
Val loss after 123120 examples: 2.383
Train loss after 123392 examples: 2.499
Val loss after 123392 examples: 2.399
Train loss after 123904 examples: 2.464
Val loss after 123904 examples: 2.407
Train loss after 124416 examples: 2.426
Val loss after 124416 examples: 2.379


 77%|███████▋  | 77/100 [5:26:19<1:37:13, 253.62s/it]

Train loss after 124740 examples: 1.961
Val loss after 124740 examples: 2.394
Train loss after 124928 examples: 2.228
Val loss after 124928 examples: 2.368
Train loss after 125440 examples: 2.259
Val loss after 125440 examples: 2.377
Train loss after 125952 examples: 2.074
Val loss after 125952 examples: 2.354


 78%|███████▊  | 78/100 [5:30:32<1:32:53, 253.36s/it]

Train loss after 126360 examples: 2.195
Val loss after 126360 examples: 2.405
Train loss after 126464 examples: 2.169
Val loss after 126464 examples: 2.378
Train loss after 126976 examples: 2.458
Val loss after 126976 examples: 2.383
Train loss after 127488 examples: 2.194
Val loss after 127488 examples: 2.375


 79%|███████▉  | 79/100 [5:34:44<1:28:33, 253.04s/it]

Train loss after 127980 examples: 2.430
Val loss after 127980 examples: 2.365
Train loss after 128000 examples: 2.286
Val loss after 128000 examples: 2.374
Train loss after 128512 examples: 2.204
Val loss after 128512 examples: 2.333
Train loss after 129024 examples: 2.283
Val loss after 129024 examples: 2.344
Train loss after 129536 examples: 2.159
Val loss after 129536 examples: 2.321


 80%|████████  | 80/100 [5:39:13<1:25:54, 257.71s/it]

Train loss after 129600 examples: 2.366
Val loss after 129600 examples: 2.323
Train loss after 130048 examples: 2.234
Val loss after 130048 examples: 2.335
Train loss after 130560 examples: 2.301
Val loss after 130560 examples: 2.353
Train loss after 131072 examples: 1.953
Val loss after 131072 examples: 2.331
Train loss after 131220 examples: 2.011
Val loss after 131220 examples: 2.300


 81%|████████  | 81/100 [5:43:23<1:20:55, 255.55s/it]

Train loss after 131584 examples: 2.146
Val loss after 131584 examples: 2.323
Train loss after 132096 examples: 2.041
Val loss after 132096 examples: 2.312
Train loss after 132608 examples: 1.871
Val loss after 132608 examples: 2.273
Train loss after 132840 examples: 2.208
Val loss after 132840 examples: 2.271


 82%|████████▏ | 82/100 [5:47:35<1:16:19, 254.43s/it]

Train loss after 133120 examples: 2.048
Val loss after 133120 examples: 2.279
Train loss after 133632 examples: 2.010
Val loss after 133632 examples: 2.278
Train loss after 134144 examples: 2.222
Val loss after 134144 examples: 2.315


 83%|████████▎ | 83/100 [5:51:47<1:11:53, 253.73s/it]

Train loss after 134460 examples: 2.143
Val loss after 134460 examples: 2.286
Train loss after 134656 examples: 1.932
Val loss after 134656 examples: 2.350
Train loss after 135168 examples: 2.347
Val loss after 135168 examples: 2.238
Train loss after 135680 examples: 2.142
Val loss after 135680 examples: 2.230


 84%|████████▍ | 84/100 [5:56:00<1:07:32, 253.28s/it]

Train loss after 136080 examples: 2.322
Val loss after 136080 examples: 2.308
Train loss after 136192 examples: 2.286
Val loss after 136192 examples: 2.310
Train loss after 136704 examples: 2.143
Val loss after 136704 examples: 2.278
Train loss after 137216 examples: 1.839
Val loss after 137216 examples: 2.259


 85%|████████▌ | 85/100 [6:00:12<1:03:14, 252.99s/it]

Train loss after 137700 examples: 2.076
Val loss after 137700 examples: 2.286
Train loss after 137728 examples: 2.035
Val loss after 137728 examples: 2.292
Train loss after 138240 examples: 2.349
Val loss after 138240 examples: 2.267
Train loss after 138752 examples: 1.898
Val loss after 138752 examples: 2.253
Train loss after 139264 examples: 2.152
Val loss after 139264 examples: 2.247


 86%|████████▌ | 86/100 [6:04:41<1:00:07, 257.70s/it]

Train loss after 139320 examples: 1.990
Val loss after 139320 examples: 2.256
Train loss after 139776 examples: 2.167
Val loss after 139776 examples: 2.324
Train loss after 140288 examples: 2.087
Val loss after 140288 examples: 2.253
Train loss after 140800 examples: 2.137
Val loss after 140800 examples: 2.254
Train loss after 140940 examples: 2.081
Val loss after 140940 examples: 2.225


 87%|████████▋ | 87/100 [6:08:54<55:34, 256.51s/it]  

Train loss after 141312 examples: 1.933
Val loss after 141312 examples: 2.196
Train loss after 141824 examples: 2.240
Val loss after 141824 examples: 2.200
Train loss after 142336 examples: 1.943
Val loss after 142336 examples: 2.183


 88%|████████▊ | 88/100 [6:13:01<50:42, 253.58s/it]

Train loss after 142560 examples: 1.926
Val loss after 142560 examples: 2.189
Train loss after 142848 examples: 2.231
Val loss after 142848 examples: 2.188
Train loss after 143360 examples: 2.148
Val loss after 143360 examples: 2.167
Train loss after 143872 examples: 2.197
Val loss after 143872 examples: 2.209


 89%|████████▉ | 89/100 [6:16:55<45:25, 247.80s/it]

Train loss after 144180 examples: 2.013
Val loss after 144180 examples: 2.203
Train loss after 144384 examples: 1.925
Val loss after 144384 examples: 2.219
Train loss after 144896 examples: 2.154
Val loss after 144896 examples: 2.198
Train loss after 145408 examples: 2.166
Val loss after 145408 examples: 2.213


 90%|█████████ | 90/100 [6:20:50<40:37, 243.78s/it]

Train loss after 145800 examples: 2.185
Val loss after 145800 examples: 2.310
Train loss after 145920 examples: 2.046
Val loss after 145920 examples: 2.153
Train loss after 146432 examples: 2.096
Val loss after 146432 examples: 2.168
Train loss after 146944 examples: 2.226
Val loss after 146944 examples: 2.137


 91%|█████████ | 91/100 [6:24:44<36:08, 240.96s/it]

Train loss after 147420 examples: 1.991
Val loss after 147420 examples: 2.153
Train loss after 147456 examples: 2.122
Val loss after 147456 examples: 2.152
Train loss after 147968 examples: 1.978
Val loss after 147968 examples: 2.155
Train loss after 148480 examples: 1.832
Val loss after 148480 examples: 2.148
Train loss after 148992 examples: 2.020
Val loss after 148992 examples: 2.135


 92%|█████████▏| 92/100 [6:28:52<32:23, 242.91s/it]

Train loss after 149040 examples: 1.723
Val loss after 149040 examples: 2.150
Train loss after 149504 examples: 2.016
Val loss after 149504 examples: 2.130
Train loss after 150016 examples: 1.900
Val loss after 150016 examples: 2.111
Train loss after 150528 examples: 1.843
Val loss after 150528 examples: 2.113


 93%|█████████▎| 93/100 [6:32:45<28:01, 240.17s/it]

Train loss after 150660 examples: 2.171
Val loss after 150660 examples: 2.295
Train loss after 151040 examples: 2.039
Val loss after 151040 examples: 2.198
Train loss after 151552 examples: 2.135
Val loss after 151552 examples: 2.134
Train loss after 152064 examples: 2.196
Val loss after 152064 examples: 2.138


 94%|█████████▍| 94/100 [6:36:39<23:49, 238.28s/it]

Train loss after 152280 examples: 1.910
Val loss after 152280 examples: 2.203
Train loss after 152576 examples: 2.047
Val loss after 152576 examples: 2.134
Train loss after 153088 examples: 2.069
Val loss after 153088 examples: 2.121
Train loss after 153600 examples: 2.161
Val loss after 153600 examples: 2.134
Train loss after 153900 examples: 1.747
Val loss after 153900 examples: 2.104


 95%|█████████▌| 95/100 [6:40:33<19:45, 237.04s/it]

Train loss after 154112 examples: 1.788
Val loss after 154112 examples: 2.102
Train loss after 154624 examples: 1.845
Val loss after 154624 examples: 2.091
Train loss after 155136 examples: 1.926
Val loss after 155136 examples: 2.120
Train loss after 155520 examples: 1.857
Val loss after 155520 examples: 2.087


 96%|█████████▌| 96/100 [6:44:28<15:45, 236.35s/it]

Train loss after 155648 examples: 1.952
Val loss after 155648 examples: 2.078
Train loss after 156160 examples: 2.116
Val loss after 156160 examples: 2.108
Train loss after 156672 examples: 2.355
Val loss after 156672 examples: 2.082


 97%|█████████▋| 97/100 [6:48:22<11:46, 235.63s/it]

Train loss after 157140 examples: 2.029
Val loss after 157140 examples: 2.113
Train loss after 157184 examples: 1.818
Val loss after 157184 examples: 2.102
Train loss after 157696 examples: 2.004
Val loss after 157696 examples: 2.068
Train loss after 158208 examples: 1.964
Val loss after 158208 examples: 2.089
Train loss after 158720 examples: 2.142
Val loss after 158720 examples: 2.070


 98%|█████████▊| 98/100 [6:52:30<07:58, 239.27s/it]

Train loss after 158760 examples: 1.893
Val loss after 158760 examples: 2.086
Train loss after 159232 examples: 1.804
Val loss after 159232 examples: 2.062
Train loss after 159744 examples: 1.897
Val loss after 159744 examples: 2.060
Train loss after 160256 examples: 2.009
Val loss after 160256 examples: 2.064


 99%|█████████▉| 99/100 [6:56:24<03:57, 237.84s/it]

Train loss after 160380 examples: 1.781
Val loss after 160380 examples: 2.081
Train loss after 160768 examples: 2.080
Val loss after 160768 examples: 2.103
Train loss after 161280 examples: 1.876
Val loss after 161280 examples: 2.080
Train loss after 161792 examples: 1.902
Val loss after 161792 examples: 2.080


100%|██████████| 100/100 [7:00:18<00:00, 252.19s/it]

Train loss after 162000 examples: 1.779
Val loss after 162000 examples: 2.098





0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train_loss,██▇▇▆▅▆▇▅▄▅▃▅▄▆▄▃▄▃▃▄▃▃▃▃▂▂▂▄▃▂▃▁▁▂▂▂▂▁▁
val_loss,█▇▇▇▆▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▃▂▂▂▂▂▂▁▁▁▁▁

0,1
epoch,99.0
train_loss,1.77871
val_loss,2.09775


# Unit Test

In [2]:
# Smoke test for DrFuse_RS_Model: one forward pass on random inputs, then
# check every prediction head's output shape and count trainable parameters.
N_SAMPLES, N_CHANNELS, HEIGHT, WIDTH = 6, 4, 512, 512
N_CLASSES = 6  # matches the (6, 6, 512, 512) shapes this cell printed previously

dummies = torch.rand(N_SAMPLES, N_CHANNELS, HEIGHT, WIDTH)
# One 0/1 float flag per sample (True -> 1.0, False -> 0.0); torch.tensor with an
# explicit dtype replaces the legacy torch.FloatTensor constructor.
# NOTE(review): presumably marks which samples have the second modality available —
# confirm against DrFuse_RS_Model.forward.
masks = torch.tensor([True, False, True, True, False, False], dtype=torch.float32)
model = DrFuse_RS_Model()

# Shape check only: no_grad avoids building and retaining an autograd graph
# over the whole 6x4x512x512 batch, which is pure memory overhead here.
with torch.no_grad():
  results = model(dummies, masks)

print(results['pred_rgb'].shape, results['pred_ndsm'].shape,
      results['pred_shared'].shape, results['pred_multimodal'].shape,
      results['aux_preds'][0].shape)

# Make the cell fail loudly on a shape regression instead of only printing.
expected_shape = (N_SAMPLES, N_CLASSES, HEIGHT, WIDTH)
for key in ('pred_rgb', 'pred_ndsm', 'pred_shared', 'pred_multimodal'):
  assert tuple(results[key].shape) == expected_shape, f"{key}: {tuple(results[key].shape)}"
assert tuple(results['aux_preds'][0].shape) == expected_shape, "aux_preds[0] shape mismatch"

# Count only the parameters the optimizer will update.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total trainable parameters: {total_params}")



torch.Size([6, 6, 512, 512]) torch.Size([6, 6, 512, 512]) torch.Size([6, 6, 512, 512]) torch.Size([6, 6, 512, 512]) torch.Size([6, 6, 512, 512])


In [None]:
# Demonstrate zeroing whole feature rows with a per-sample 0/1 flag via broadcasting.
# torch.tensor(..., dtype=torch.float32) replaces the legacy torch.FloatTensor
# constructor; the resulting values are identical (True -> 1.0, False -> 0.0).
pairs = torch.tensor([True, False, True, False, True, True], dtype=torch.float32)
pairs = pairs.unsqueeze(1)  # (6,) -> (6, 1) so it broadcasts across the 512 feature columns
feat_dummy = torch.randn(6, 512)
test = pairs * feat_dummy

# Rows flagged 0 must be all zeros; rows flagged 1 must pass through unchanged.
keep = pairs.squeeze(1).bool()
assert tuple(test.shape) == tuple(feat_dummy.shape)
assert bool((test[~keep] == 0).all())
assert torch.equal(test[keep], feat_dummy[keep])
test

In [None]:
# Pull just the first validation batch (if any) and show its tensor shapes and mask.
first_val_batch = next(iter(validate_loader), None)
if first_val_batch is not None:
  input_4C, gt, mask = first_val_batch
  print(input_4C.shape, gt.shape, mask.shape)
  print(mask)