In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import ListedColormap, Normalize

import ibot_loader

In [2]:
def plot_tensor_list(tensor_list):
    """Show every grid in ``tensor_list`` as its own matplotlib figure.

    Each grid is drawn with a fixed ten-colour palette normalised over the
    value range 0..9. Cells equal to ``-1`` additionally receive a gray
    checkerboard overlay (strong alpha on cells where ``row + col`` is even,
    faint alpha everywhere else) so they stand out from regular values.

    Relies on ``np``, ``plt``, ``ListedColormap`` and ``Normalize`` being
    imported at notebook level.

    Args:
        tensor_list: iterable of ``torch.Tensor`` or array-like grids. Each
            item must be 2-D; leading axes of length 1 (e.g. ``(1, H, W)``)
            are dropped automatically.

    Returns:
        None. One figure per item is shown and then closed.

    Raises:
        ValueError: if an item is not 2-D after dropping leading singleton
            axes.
    """
    cmap = ListedColormap([
        '#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
        '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25'
    ])
    norm = Normalize(vmin=0, vmax=9)

    def plot_with_gray_stripes(ax, tensor):
        """Draw one grid on ``ax`` and overlay the checkerboard on -1 cells."""
        if isinstance(tensor, torch.Tensor):
            # detach()/cpu() keep this working for tensors that require grad
            # or live on an accelerator; plain .numpy() raises for both.
            array = tensor.detach().cpu().numpy()
        else:
            array = np.asarray(tensor)

        # Accept (1, H, W)-style inputs by peeling leading singleton axes.
        # Only axis 0 is peeled so a genuine (1, W) grid stays 2-D.
        while array.ndim > 2 and array.shape[0] == 1:
            array = array[0]
        if array.ndim != 2:
            raise ValueError(
                f'expected a 2-D grid, got an array of shape {array.shape}'
            )

        mask = array == -1
        ax.imshow(array, cmap=cmap, norm=norm)

        h, w = array.shape
        # RGBA overlay: light gray everywhere, almost transparent by default.
        stripe_pattern = np.full((h, w, 4), 0.6)
        stripe_pattern[:, :, 3] = 0.1

        # Vectorised checkerboard: raise alpha on -1 cells with even row+col.
        rows, cols = np.indices((h, w))
        stripe_pattern[mask & ((rows + cols) % 2 == 0), 3] = 0.8

        ax.imshow(stripe_pattern, interpolation='none')
        ax.axis('off')

    for i, tensor in enumerate(tensor_list):
        fig, ax = plt.subplots()
        try:
            plot_with_gray_stripes(ax, tensor)
            ax.set_title(f'Tensor {i+1}')
            plt.show()
        finally:
            # Release the figure even if drawing failed, so long lists do
            # not accumulate open figures in the kernel.
            plt.close(fig)

In [3]:
class Args:
    """Plain namespace of run settings, used in place of parsed CLI arguments.

    An instance is handed to ``ibot_loader.iBOT_DatasetWrapper`` later in the
    notebook. The section labels below are derived from the field names only;
    which fields are actually consumed, and how, is decided by code that is
    not part of this notebook.
    """

    # --- architecture / output heads ---
    arch = "vit_small"
    patch_size = 4
    window_size = 7
    out_dim = 128
    patch_out_dim = 128
    norm_last_layer = True
    momentum_teacher = 0.996
    use_masked_im_modeling = True

    # --- mask prediction ---
    pred_ratio = [0.3]
    pred_ratio_var = [0]
    pred_shape = "block"
    pred_start_epoch = 0

    # --- loss weights ---
    lambda1 = 1.0
    lambda2 = 1.0

    # --- teacher temperatures ---
    warmup_teacher_temp = 0.04
    teacher_temp = 0.04
    warmup_teacher_patch_temp = 0.04
    teacher_patch_temp = 0.07
    warmup_teacher_temp_epochs = 30

    # --- optimisation ---
    use_fp16 = True
    weight_decay = 0.04
    weight_decay_end = 0.4
    clip_grad = 3.0
    batch_size_per_gpu = 128
    epochs = 100
    freeze_last_layer = 1
    lr = 5e-4
    warmup_epochs = 10
    min_lr = 1e-6
    optimizer = "adamw"
    load_from = None
    drop_path = 0.1

    # --- crops ---
    global_crops_number = 2
    global_crops_scale = (0.14, 1.0)
    pad_to_32 = True
    local_crops_number = 0
    local_crops_scale = (0.05, 0.4)

    # --- output / runtime ---
    output_dir = "trained_models/"
    saveckp_freq = 40
    seed = 0
    num_workers = 1
    dist_url = "env://"
    local_rank = 0


# Single shared settings object for the rest of the notebook.
args = Args()

# Never summarise printed tensors with "...": show every element.
torch.set_printoptions(threshold=float("inf"))

In [4]:
# Load the raw dataset through the project helper.
dataset = ibot_loader.get_dataset()

# Shuffle with a generator seeded from args.seed so that Restart & Run All
# draws samples in the same order every time. Any randomness applied inside
# iBOT_DatasetWrapper itself is outside this generator's control.
sampler = torch.utils.data.RandomSampler(
    dataset,
    generator=torch.Generator().manual_seed(args.seed),
)

# The wrapper receives the run settings; batches are assembled by the
# project's own collate function rather than the DataLoader default.
# NOTE(review): the sampler indexes `dataset` while the loader reads from
# `dataset_wrapper` - this assumes both have the same length; confirm.
dataset_wrapper = ibot_loader.iBOT_DatasetWrapper(dataset, args)
data_loader = torch.utils.data.DataLoader(
    dataset_wrapper,
    sampler=sampler,
    collate_fn=ibot_loader.custom_collate_fn,
)
print(f"Data loaded: there are {len(dataset)} images.")

Data loaded: there are 6581 images.


  return torch.tensor(sample.tolist(), dtype=torch.float32)


In [5]:
# Pull one batch from the loader; it unpacks into an image part and a mask part.
imgs, msks = next(iter(data_loader))

# Materialise both parts as plain Python lists (one entry per item).
images = list(imgs)
masks = list(msks)

# Dump the masks in full for inspection.
print(masks)

[tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0