In [None]:
import torchvision
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.optim as optim


from torch import nn
import torchvision.models as models

device = 'cuda' if torch.cuda.is_available() else 'cpu'

IMG_SIZE = 256
# The trimap masks contain label values 1..3 (observed min/max on the training
# set), and CenterCrop pads undersized masks with 0. CrossEntropyLoss requires
# every target in [0, NUM_CLASSES), so the head must emit 4 logits (ids 0..3);
# id 0 is the padding value, which `iou` skips and `color_dict` also covers.
NUM_CLASSES = 4

# data

In [None]:
# Image pipeline: center-crop to IMG_SIZE, convert to a float tensor in [0, 1],
# then map every channel to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)


def _mask_to_tensor(mask):
    """Turn a (cropped) PIL mask into a tensor that keeps the raw label values."""
    return torch.tensor(np.array(mask))


# Target pipeline: the same crop as the images so pixels stay aligned, then the
# raw integer labels (no scaling or normalization).
target_transform = transforms.Compose([
    transforms.CenterCrop(IMG_SIZE),
    _mask_to_tensor,
])


In [None]:
# Oxford-IIIT Pet with segmentation trimaps as targets (downloaded on first use);
# both splits share the same image and mask pipelines.
_pet_kwargs = dict(
    root='.',
    target_types=['segmentation'],
    download=True,
    transform=transform,
    target_transform=target_transform,
)
data_train = torchvision.datasets.OxfordIIITPet(split='trainval', **_pet_kwargs)
data_test = torchvision.datasets.OxfordIIITPet(split='test', **_pet_kwargs)

Downloading https://thor.robots.ox.ac.uk/pets/images.tar.gz to oxford-iiit-pet/images.tar.gz


100%|██████████| 792M/792M [00:29<00:00, 26.5MB/s]


Extracting oxford-iiit-pet/images.tar.gz to oxford-iiit-pet
Downloading https://thor.robots.ox.ac.uk/pets/annotations.tar.gz to oxford-iiit-pet/annotations.tar.gz


100%|██████████| 19.2M/19.2M [00:01<00:00, 9.87MB/s]


Extracting oxford-iiit-pet/annotations.tar.gz to oxford-iiit-pet


In [None]:
# Sanity-check the mask label values on a handful of training samples. The
# previous loop decoded the entire split and printed two lines per image.
N_LABEL_CHECK = 32
observed_labels = set()
for sample_idx in range(min(N_LABEL_CHECK, len(data_train))):
    _, mask = data_train[sample_idx]
    observed_labels.update(torch.unique(mask).tolist())
sorted(observed_labels)

tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtype=torch.uint8) tensor(1, dtype=torch.uint8)
tensor(3, dtyp

# segmentation head

In [None]:
!wget https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar

--2024-12-27 09:31:31--  https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.163.189.108, 3.163.189.14, 3.163.189.96, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.163.189.108|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 112133139 (107M) [application/octet-stream]
Saving to: ‘moco_v2_800ep_pretrain.pth.tar’


2024-12-27 09:31:32 (221 MB/s) - ‘moco_v2_800ep_pretrain.pth.tar’ saved [112133139/112133139]



In [None]:
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling block (DeepLab-v3 style).

    Five parallel branches are applied to the input feature map:
    a 1x1 conv, three 3x3 atrous convs with dilation 6/12/18 (scaled by
    ``mult``, padding equal to dilation), and an image-level branch
    (global average pool -> 1x1 conv, broadcast back to the input size).
    Each branch is followed by ``norm`` + ReLU. The concatenation is fused by
    a 1x1 conv + ``norm`` + ReLU.

    Args:
        C: number of input channels.
        depth: channels produced by every branch and by the output.
        conv: convolution constructor (default ``nn.Conv2d``).
        norm: normalization constructor, called as ``norm(depth, momentum=...)``.
        momentum: running-statistics momentum passed to every ``norm`` layer.
        mult: multiplier applied to the dilation rates (and matching padding).

    Shape: ``(N, C, H, W) -> (N, depth, H, W)``.
    NOTE(review): with BatchNorm, the 1x1 image-level branch cannot be trained
    with batch size 1 (a single value per channel).
    """

    def __init__(self, C, depth, conv=nn.Conv2d, norm=nn.BatchNorm2d, momentum=0.0003, mult=1):
        super(ASPP, self).__init__()
        self._C = C
        self._depth = depth

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.relu = nn.ReLU(inplace=True)
        self.aspp1 = conv(C, depth, kernel_size=1, stride=1, bias=False)
        self.aspp2 = conv(C, depth, kernel_size=3, stride=1,
                          dilation=int(6 * mult), padding=int(6 * mult),
                          bias=False)
        self.aspp3 = conv(C, depth, kernel_size=3, stride=1,
                          dilation=int(12 * mult), padding=int(12 * mult),
                          bias=False)
        self.aspp4 = conv(C, depth, kernel_size=3, stride=1,
                          dilation=int(18 * mult), padding=int(18 * mult),
                          bias=False)
        self.aspp5 = conv(C, depth, kernel_size=1, stride=1, bias=False)
        # momentum must be passed by keyword: BatchNorm2d's 2nd positional
        # parameter is `eps`. The former `norm(depth, momentum)` set eps=0.0003
        # and left momentum at its 0.1 default.
        self.aspp1_bn = norm(depth, momentum=momentum)
        self.aspp2_bn = norm(depth, momentum=momentum)
        self.aspp3_bn = norm(depth, momentum=momentum)
        self.aspp4_bn = norm(depth, momentum=momentum)
        self.aspp5_bn = norm(depth, momentum=momentum)
        self.conv2 = conv(depth * 5, depth, kernel_size=1, stride=1,
                          bias=False)
        self.bn2 = norm(depth, momentum=momentum)

    def forward(self, x):
        height, width = x.shape[2], x.shape[3]
        x1 = self.relu(self.aspp1_bn(self.aspp1(x)))
        x2 = self.relu(self.aspp2_bn(self.aspp2(x)))
        x3 = self.relu(self.aspp3_bn(self.aspp3(x)))
        x4 = self.relu(self.aspp4_bn(self.aspp4(x)))
        # Image-level branch: pool to 1x1, project, then broadcast back to HxW.
        x5 = self.relu(self.aspp5_bn(self.aspp5(self.global_pooling(x))))
        x5 = nn.functional.interpolate(x5, size=(height, width), mode='bilinear',
                                       align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), 1)
        return self.relu(self.bn2(self.conv2(x)))

In [None]:
class SegmentationHead(nn.Module):
    """DeepLab-v3+-style decoder over two encoder feature maps.

    ``forward`` expects a dict with:
      * ``'out'``    - deep, low-resolution features with ``dim_struct`` channels
                       (projected to 2048 channels and fed to ASPP);
      * ``'struct'`` - shallower, higher-resolution features with ``dim_out``
                       channels (projected to 256, then reduced to 48);
      * ``'img'``    - the input batch; only its spatial size is read, to upsample
                       the logits back to label resolution.
    Returns logits of shape ``(N, num_classes, H_img, W_img)``.

    NOTE(review): the parameter names look swapped relative to the feature keys
    (``dim_struct`` sizes ``'out'``, ``dim_out`` sizes ``'struct'``). They are
    kept as-is for compatibility with positional callers.
    """

    def __init__(self, num_classes, dim_struct=2048, dim_out=256) -> None:
        super().__init__()
        # Submodules are created and registered in the same order as before, so
        # parameter order (and hence saved optimizer state) stays compatible.
        self.aspp = ASPP(2048, 256)
        self.low_level_feature_reducer = nn.Sequential(
            nn.Conv2d(256, 48, 1),
            nn.BatchNorm2d(48, momentum=0.0003),
            nn.ReLU(),
        )
        decoder_layers = [
            nn.Conv2d(256 + 48, 256, 3, padding=1),
            nn.BatchNorm2d(256, momentum=0.0003),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256, momentum=0.0003),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 3, padding=1),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

        self.proj1 = nn.Conv2d(dim_struct, 2048, 3, 1, 1)
        self.proj2 = nn.Conv2d(dim_out, 256, 3, 1, 1)

    def forward(self, features):
        struct_size = features['struct'].shape[-2:]
        label_size = features["img"].shape[-2:]

        deep = self.aspp(self.proj1(features['out']))
        deep = nn.functional.interpolate(deep, size=struct_size, mode='bilinear',
                                         align_corners=True)

        shallow = self.low_level_feature_reducer(self.proj2(features['struct']))
        logits = self.decoder(torch.cat((shallow, deep), dim=1))
        return nn.functional.interpolate(logits, size=label_size, mode='bilinear',
                                         align_corners=True)


# DDAE encoder

In [None]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from pathlib import Path
import random
import numpy as np
from sklearn.metrics import accuracy_score
import time
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter


In [None]:
def set_deterministic(seed=42):
    """Seed Python, NumPy and PyTorch (CPU, current and all CUDA devices) RNGs
    and force deterministic, non-benchmarking cuDNN kernels."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def readable_number(num):
    """Insert a comma every three characters of ``str(num)``, counting from the right.

    Intended for non-negative integers, e.g. ``140627128 -> '140,627,128'``.
    """
    text = str(num)
    groups = []
    end = len(text)
    while end > 0:
        groups.append(text[max(0, end - 3):end])
        end -= 3
    return ','.join(reversed(groups))

def log(writer, metrics, epoch):
    """Write train/test loss and accuracy for ``epoch`` via ``writer.add_scalars``, then flush.

    ``metrics`` must contain ``loss_train``, ``loss_test``, ``accuracy_train`` and
    ``accuracy_test``.
    """
    for name in ('loss', 'accuracy'):
        pair = {'train': metrics[name + '_train'], 'test': metrics[name + '_test']}
        writer.add_scalars(name, pair, epoch)
    writer.flush()

def save_checkpoint(state, path, epoch, test_loss):
    """Save ``state`` as ``{path}/{epoch}_valloss={test_loss:.3f}.pt``, creating ``path`` if needed.

    NOTE(review): a later cell redefines ``save_checkpoint(model, optimizer,
    filename)``; once that cell runs, it shadows this version.
    """
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    filename = f'{path}/{epoch}_valloss={test_loss:.3f}.pt'
    torch.save(state, filename)

def get_parameters(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``."""
    total = trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return total, trainable

def print_parameters(model):
    """Print trainable vs. total parameter counts of ``model`` and the trainable share (%)."""
    total, trainable = get_parameters(model)
    print(
        f'model initialized with trainable params: {readable_number(trainable)}'
        f' || total params: {readable_number(total)}'
        f' || trainable%: {trainable / total * 100:.3f}'
    )

In [None]:
class DiffusionEncoder(nn.Module):
    """Feature extractor that replays a diffusers ``UNet2DModel``'s forward pass.

    Runs the UNet's time embedding, ``conv_in``, down blocks, mid block and
    (optionally only a prefix of) its up blocks, without the output head
    (``conv_norm_out``/``conv_act``/``conv_out`` are not applied).
    Intermediate maps are stored in ``self.features``:
      * ``'out'``    - mid-block output (deepest feature map);
      * ``'struct'`` - output of up block ``up_last`` (only when that up block
                       is reached).
    """

    def __init__(self, unet):
        super().__init__()
        self.unet = unet
        self.features = {}

    def forward(self, imgs, timestep, class_labels=None, up_last=-1):
        """Encode ``imgs`` at diffusion ``timestep``.

        Args:
            imgs: image batch; presumably ``(N, C, H, W)`` as expected by
                ``unet.conv_in``.
            timestep: Python number, 0-d tensor or tensor of timesteps; it is
                broadcast to the batch size.
            class_labels: required iff the UNet has a class embedding.
            up_last: index of the up block after which to stop. The default
                ``-1`` runs every up block.

        Returns:
            ``(N, C')`` spatial means of up block ``up_last``'s output when that
            block is reached, otherwise the last up block's full output.

        Raises:
            ValueError: if ``class_labels`` does not match the UNet's
                class-embedding configuration.
        """
        unet = self.unet

        # 0. center input if necessary
        if unet.config.center_input_sample:
            imgs = 2 * imgs - 1.0

        # 1. time embedding
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=imgs.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(imgs.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(imgs.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        # time_proj has no weights and returns f32, while time_embedding may run
        # in fp16, so cast to the UNet's dtype first.
        t_emb = unet.time_proj(timesteps).to(dtype=unet.dtype)
        emb = unet.time_embedding(t_emb)

        if unet.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")
            if unet.config.class_embed_type == "timestep":
                class_labels = unet.time_proj(class_labels)
            class_emb = unet.class_embedding(class_labels).to(dtype=unet.dtype)
            emb = emb + class_emb
        elif class_labels is not None:
            raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

        # 2. pre-process
        skip_sample = imgs
        imgs = unet.conv_in(imgs)

        # 3. down: collect the residuals that the up blocks consume later
        down_block_res_samples = (imgs,)
        for downsample_block in unet.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                imgs, res_samples, skip_sample = downsample_block(
                    hidden_states=imgs, temb=emb, skip_sample=skip_sample
                )
            else:
                imgs, res_samples = downsample_block(hidden_states=imgs, temb=emb)
            down_block_res_samples += res_samples

        # 4. mid
        imgs = unet.mid_block(imgs, emb)
        self.features['out'] = imgs

        # 5. up: optionally stop early and return pooled features
        skip_sample = None
        for i, upsample_block in enumerate(unet.up_blocks):
            n_res = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-n_res:]
            down_block_res_samples = down_block_res_samples[:-n_res]

            if hasattr(upsample_block, "skip_conv"):
                imgs, skip_sample = upsample_block(imgs, res_samples, emb, skip_sample)
            else:
                imgs = upsample_block(imgs, res_samples, emb)

            if up_last == i:
                self.features['struct'] = imgs
                return imgs.mean(dim=[2, 3])

        return imgs


In [None]:
from diffusers import UNet2DModel

# Pretrained DDPM UNet used as the diffusion encoder backbone.
# Smaller alternative checkpoint: "google/ddpm-cifar10-32".
hf_unet = UNet2DModel.from_pretrained(pretrained_model_name_or_path="google/ddpm-cat-256")

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/790 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/455M [00:00<?, ?B/s]

In [None]:
# Dummy batch (3 images, 3 channels, 224x224) and a single timestep 0, used to
# probe the encoder's feature shapes.
img = torch.randn(3, 3, 224, 224)
t = torch.tensor([0])

In [None]:
encoder = DiffusionEncoder(unet=hf_unet)
# Stop after up block 2: returns pooled (N, C) features and fills
# encoder.features with both 'out' and 'struct'.
out = encoder(img, t, up_last=2)
out.shape

torch.Size([3, 256])

In [None]:
encoder.features['out'].size()  # mid-block feature map

torch.Size([3, 512, 7, 7])

In [None]:
encoder.features['struct'].size()  # output of up block 2

torch.Size([3, 256, 56, 56])

# segmentator models

In [None]:
class ResnetSegmentator(nn.Module):
    """Pair a ResNet-style ``encoder`` with a segmentation head.

    Forward hooks on ``encoder.layer1`` and ``encoder.layer4`` copy those layers'
    outputs into ``self.features`` under ``'struct'`` and ``'out'``. The input
    batch is stored under ``'img'``. The encoder's own return value is
    discarded; the head receives the feature dict.
    """

    def __init__(self, encoder, segmentator):
        super().__init__()
        self.encoder = encoder
        self.segmentator = segmentator
        self.features = {}

        def _capture(key):
            # Build a forward hook that stores a copy of the layer's output under `key`.
            def hook(module, inputs, output):
                self.features[key] = output.clone()
            return hook

        self.hooks = [
            self.encoder.layer1.register_forward_hook(_capture('struct')),
            self.encoder.layer4.register_forward_hook(_capture('out')),
        ]

    def forward(self, x):
        self.features['img'] = x
        self.encoder(x)
        return self.segmentator(self.features)



In [None]:
# Smoke test: randomly initialised ResNet-50 + a 13-class head on a dummy batch.
model = models.resnet50(weights=None)
sh = SegmentationHead(num_classes=13)
resseg = ResnetSegmentator(encoder=model, segmentator=sh)
dummy_input = torch.randn(2, 3, IMG_SIZE, IMG_SIZE)
resseg(dummy_input).shape

torch.Size([2, 13, 256, 256])

In [None]:
# class DDAESegmentator(nn.Module):
#   def __init__(self, encoder, segmentator):
#     super().__init__()
#     self.encoder = encoder
#     self.segmentator = segmentator


#   def forward(self, x):
#     encoder.features['img'] = x
#     t = torch.zeros(x.shape[0]).to(x.device)
#     encoder(x, t, up_last=2)
#     return self.segmentator(encoder.features)



In [None]:
class DDAESegmentator(nn.Module):
    """Segmentation model on top of a ``DiffusionEncoder``.

    Every forward pass draws one timestep per sample uniformly from
    ``[0, max_timestep)``, noises the input with ``noise_scheduler.add_noise``,
    runs the encoder up to up block 2, and feeds the encoder's feature dict
    (plus the clean input under ``'img'``) to ``segmentator``.
    NOTE: noise is applied in eval mode as well, so predictions are stochastic.

    Args:
        encoder: a ``DiffusionEncoder`` (or anything exposing ``features`` and
            ``encoder(x, t, up_last=...)``).
        segmentator: head called with the feature dict.
        noise_scheduler: object providing ``add_noise(x, noise, t)``.
        max_timestep: exclusive upper bound of sampled timesteps (default 50,
            the previously hard-coded value).
    """

    def __init__(self, encoder, segmentator, noise_scheduler, max_timestep=50):
        super().__init__()
        self.encoder = encoder
        self.segmentator = segmentator
        self.noise_scheduler = noise_scheduler
        self.max_timestep = max_timestep

    def forward(self, x):
        # Use the wrapped encoder rather than the notebook-global `encoder`,
        # which later cells rebind (e.g. to a ResNet).
        self.encoder.features['img'] = x  # clean input; the head reads its size
        t = torch.randint(0, self.max_timestep, size=(x.shape[0],)).to(x.device)
        noise = torch.randn_like(x)
        noisy = self.noise_scheduler.add_noise(x, noise, t)
        self.encoder(noisy, t, up_last=2)
        return self.segmentator(self.encoder.features)



In [None]:
from diffusers import UNet2DModel, DDPMPipeline
pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")
hf_unet, hf_scheduler = pipe.unet, pipe.scheduler
# Wrap the UNet, not the pipeline. DiffusionEncoder reads UNet attributes
# (config.center_input_sample, time_proj, down_blocks, ...) directly.
encoder = DiffusionEncoder(hf_unet)

model_index.json:   0%|          | 0.00/167 [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

scheduler_config.json:   0%|          | 0.00/256 [00:00<?, ?B/s]

diffusion_pytorch_model.bin:   0%|          | 0.00/455M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Smoke test of the diffusion segmentator: the mid block gives 512 channels
# ('out') and up block 2 gives 256 channels ('struct').
encoder = DiffusionEncoder(unet=hf_unet)
sh = SegmentationHead(13, dim_struct=512, dim_out=256)
resseg = DDAESegmentator(encoder=encoder, segmentator=sh, noise_scheduler=hf_scheduler)
dummy_input = torch.randn(2, 3, IMG_SIZE, IMG_SIZE)
resseg(dummy_input).shape

torch.Size([2, 13, 256, 256])

# training loop

In [None]:
def load_mocov2(model, pretrained):
    """Load a MoCo v2 checkpoint's query-encoder weights into ``model``.

    Keys under ``module.encoder_q`` (except its ``fc`` head) are renamed by
    stripping the ``module.encoder_q.`` prefix; every other key is dropped. The
    load is non-strict, and the function asserts that only ``fc.weight`` and
    ``fc.bias`` are missing. Afterwards ``model.fc`` is replaced by
    ``nn.Identity()`` (presumably so a torchvision ResNet returns pooled features).
    """
    checkpoint = torch.load(pretrained, map_location="cpu")
    weights = checkpoint["state_dict"]
    prefix = "module.encoder_q."

    for key in list(weights):
        wanted = key.startswith("module.encoder_q") and not key.startswith("module.encoder_q.fc")
        if wanted:
            weights[key[len(prefix):]] = weights[key]
        del weights[key]

    result = model.load_state_dict(weights, strict=False)
    assert set(result.missing_keys) == {"fc.weight", "fc.bias"}
    print("=> loaded pre-trained model '{}'".format(pretrained))
    model.fc = nn.Identity()


In [None]:
def save_checkpoint(model, optimizer, filename='checkpoint.pt'):
    """Save the ``state_dict`` of ``model`` and ``optimizer`` to ``filename``.

    Missing parent directories are created first; otherwise ``torch.save`` fails
    for a fresh checkpoint directory.

    NOTE(review): this redefines the earlier ``save_checkpoint(state, path, epoch,
    test_loss)`` helper. After this cell runs, only this signature exists.
    """
    from pathlib import Path

    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, filename)


def load_checkpoint(model, optimizer=None, filename='checkpoint.pt', map_location=None):
    """Restore ``model`` (and optionally ``optimizer``) from a checkpoint dict.

    The file must hold ``'model_state_dict'`` and, if ``optimizer`` is given,
    ``'optimizer_state_dict'``.

    Args:
        model: module whose ``state_dict`` is loaded (strict).
        optimizer: if not None, its state is restored as well.
        filename: checkpoint path.
        map_location: forwarded to ``torch.load``, e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine.
    """
    # The checkpoint only contains state dicts, so restrict unpickling to
    # tensors and primitives; being explicit also avoids torch.load's
    # FutureWarning about the weights_only default.
    checkpoint = torch.load(filename, map_location=map_location, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


In [None]:
# Fixed test-set indices whose predictions are logged at every evaluation.
LOG_IMGS = [0, 19, 28, 576, 1032]
log_batch = torch.stack([data_test[ind][0] for ind in LOG_IMGS]).to(device)
log_batch.shape

torch.Size([5, 3, 256, 256])

In [None]:
color_dict = {
    0: (255, 87, 51),   # Red-Orange
    1: (51, 255, 87),   # Green
    2: (51, 87, 255),   # Blue
    3: (255, 51, 161),  # Pink
}
def get_color_labels(labels):
    """Map an integer label map to an RGB ``uint8`` image using ``color_dict``.

    Args:
        labels: tensor (CPU or GPU) or array of integer class ids. Its last two
            dimensions are treated as ``(H, W)``; any leading dimensions must
            have size 1.

    Returns:
        ``np.ndarray`` of shape ``(H, W, 3)``, dtype ``uint8``.

    A vectorised lookup table replaces a per-pixel Python loop, and the output
    size follows the input instead of assuming ``IMG_SIZE``.
    """
    lut = np.zeros((max(color_dict) + 1, 3), dtype=np.uint8)
    for label, rgb in color_dict.items():
        lut[label] = rgb
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    index = np.asarray(labels).astype(np.int64)
    return lut[index].reshape(*index.shape[-2:], 3)

In [None]:
def init_image_log():
    """Log the fixed ``LOG_IMGS`` test images and their color-coded masks to wandb."""
    # Decode each sample once; previously every index was loaded twice
    # (once for the image list, once for the label list).
    samples = [data_test[ind] for ind in LOG_IMGS]
    wandb.log({'images': [wandb.Image(image) for image, _ in samples],
               'labels': [wandb.Image(get_color_labels(label)) for _, label in samples]})

In [None]:
@torch.no_grad()
def get_log_preds(model):
    """Predict on the global ``log_batch`` and return one color-coded
    ``wandb.Image`` per sample.

    Runs under CUDA autocast with gradients disabled. It does not change
    ``model``'s train/eval mode.
    """
    with torch.amp.autocast(device_type='cuda'):
        class_maps = model(log_batch).argmax(1).detach().cpu()
    return [wandb.Image(get_color_labels(class_map)) for class_map in class_maps]


In [None]:
from tqdm.notebook import tqdm
import wandb

def iou(out, labels):
    """Per-class IoU of ``argmax(out, dim=1)`` against ``labels`` for one batch.

    Args:
        out: logits of shape ``(N, C, H, W)``.
        labels: integer targets of shape ``(N, H, W)``.

    Returns:
        List with one float per class id present in ``labels``. Class 0 is
        skipped (presumably the CenterCrop padding value). Each float is the
        mean IoU over the images in which that class occurs in the label or
        the prediction. Images where the class is absent from both are
        excluded; counting them as a perfect (0+eps)/(0+eps) = 1.0 inflated
        the score.
    """
    ious = []
    prediction = torch.argmax(out, dim=1)
    for cat in torch.unique(labels).tolist():
        cat = int(cat)
        if cat == 0:
            continue
        pred_mask = prediction == cat
        label_mask = labels == cat
        intersection = (pred_mask & label_mask).sum(dim=(-2, -1))
        union = (pred_mask | label_mask).sum(dim=(-2, -1))
        valid = union > 0
        if valid.any():
            ious.append((intersection[valid].float() / union[valid].float()).mean().item())
    return ious


class Trainer:
  def __init__(self, model, optimizer, criterion) -> None:
    self.model = model
    self.criterion = criterion
    self.optimizer = optimizer
    self.scaler = torch.amp.GradScaler()


  def train_step(self, imgs, labels):
    with torch.amp.autocast(device_type='cuda'):
      preds = self.model(imgs)
      self.optimizer.zero_grad()
      loss = self.criterion(preds, labels)

      self.scaler.scale(loss).backward()
      self.scaler.step(optimizer)
      self.scaler.update()

    return loss

  def train_loop(self, n_epochs, loader, test_loader,project='ssl_proj', log_every=10,
                 ckpt_path='.',
                 run_name=None, run_id=None, resume = 0):

    iter = resume*len(loader)
    if resume> 0:
      load_checkpoint(model, optimizer, f'{ckpt_path}/{run_name}_ckpt{resume}.pt')
    model.train()

    with wandb.init(project=project, name=run_name, id=run_id, resume='allow'):
      if resume==0:
        init_image_log()

      for epoch in tqdm(range(resume, n_epochs), desc='epoch'):
        for img, label in tqdm(loader, desc='training'):
          img = img.to(device)
          label = label.to(device).long()
          loss = self.train_step(img, label)

          if iter%log_every ==0 :
            m = self.evaluate(test_loader)
            wandb.log({'loss': loss, 'iou': m,'epoch': epoch, 'iter': iter,
                       'preds': get_log_preds(self.model)})
            model.train()

          iter+=1

        if (epoch+1) % 10 == 0:
            path = f'{ckpt_path}/{run_name}_ckpt{epoch+1}.pt'
            save_checkpoint(self.model, self.optimizer, path)

  def evaluate(self, loader):
    ious = []
    model.eval()
    with torch.amp.autocast(device_type='cuda'):
      with torch.no_grad():
        for img, label in tqdm(loader, desc='evaluating'):
          img = img.to(device)
          label = label.to(device).long()
          pred = self.model(img)
          ious += iou(pred, label)

    return np.mean(ious)




## train diffusion

In [None]:
from diffusers import UNet2DModel, DDPMPipeline
pipe = DDPMPipeline.from_pretrained(pretrained_model_name_or_path="google/ddpm-cat-256")
hf_unet = pipe.unet            # backbone for DiffusionEncoder
hf_scheduler = pipe.scheduler  # provides add_noise for DDAESegmentator

Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
encoder = DiffusionEncoder(unet=hf_unet)

In [None]:
segmentation_head = SegmentationHead(NUM_CLASSES, dim_struct=512, dim_out=256)

In [None]:
model = DDAESegmentator(encoder, segmentation_head, hf_scheduler).to(device)
# Only the head is optimized, so freeze the diffusion UNet. Otherwise every
# step computes and stores UNet gradients that the optimizer never uses (and
# its zero_grad never clears). The head's own gradients are unaffected.
model.encoder.requires_grad_(False)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(segmentation_head.parameters(), lr=0.0003)
trainer = Trainer(model, optimizer, criterion)

In [None]:
# NOTE(review): duplicate of the earlier get_parameters helper (same behaviour).
def get_parameters(model):
    """Count parameter elements of ``model``; returns ``(total, trainable)``."""
    counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(n for n, _ in counts)
    trainable = sum(n for n, needs_grad in counts if needs_grad)
    return total, trainable

# NOTE(review): duplicate of the earlier print_parameters helper (same output).
def print_parameters(model):
    """Print a one-line summary of trainable/total parameter counts for ``model``."""
    total, trainable = get_parameters(model)
    trainable_str = readable_number(trainable)
    total_str = readable_number(total)
    print(f'model initialized with trainable params: {trainable_str} || total params: {total_str} || trainable%: {trainable/total * 100:.3f}')

In [None]:
print_parameters(model=model)

model initialized with trainable params: 140,627,128 || total params: 140,627,128 || trainable%: 100.000


In [None]:
from torch.utils.data import DataLoader

_ddae_loader_kwargs = dict(batch_size=16, num_workers=2)
train_loader = DataLoader(data_train, shuffle=True, **_ddae_loader_kwargs)
test_loader = DataLoader(data_test, shuffle=False, **_ddae_loader_kwargs)


In [None]:
from google.colab import drive
drive.mount(mountpoint='/content/drive')  # checkpoints are written to Google Drive

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
ckpt_path = '/content/drive/MyDrive/gpu_shared/ssl/ssl_proj'  # checkpoint directory on Drive

In [None]:
trainer.train_loop(
    100,
    train_loader,
    test_loader,
    ckpt_path=ckpt_path,
    run_name='diffsuion_segmentator_noise',  # (sic) kept: checkpoint filenames derive from it
    log_every=len(train_loader),  # evaluate and log once per epoch
)

epoch:   0%|          | 0/100 [00:00<?, ?it/s]

training:   0%|          | 0/230 [00:00<?, ?it/s]

evaluating:   0%|          | 0/230 [00:00<?, ?it/s]

training:   0%|          | 0/230 [00:00<?, ?it/s]

evaluating:   0%|          | 0/230 [00:00<?, ?it/s]

training:   0%|          | 0/230 [00:00<?, ?it/s]

evaluating:   0%|          | 0/230 [00:00<?, ?it/s]

training:   0%|          | 0/230 [00:00<?, ?it/s]

evaluating:   0%|          | 0/230 [00:00<?, ?it/s]

training:   0%|          | 0/230 [00:00<?, ?it/s]

evaluating:   0%|          | 0/230 [00:00<?, ?it/s]

training:   0%|          | 0/230 [00:00<?, ?it/s]

evaluating:   0%|          | 0/230 [00:00<?, ?it/s]

## train baseline

In [None]:
# encoder = models.resnet50(weights=None)
# load_mocov2(encoder, '/content/moco_v2_800ep_pretrain.pth.tar')
import torch
encoder = torch.hub.load(repo_or_dir='facebookresearch/swav:main', model='resnet50')
encoder.fc = nn.Identity()  # drop the classifier; ResnetSegmentator reads layer1/layer4 via hooks

Downloading: "https://github.com/facebookresearch/swav/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar" to /root/.cache/torch/hub/checkpoints/swav_800ep_pretrain.pth.tar
100%|██████████| 108M/108M [00:00<00:00, 204MB/s] 


In [None]:
segmentation_head = SegmentationHead(num_classes=NUM_CLASSES)  # default dims: 2048 ('out'), 256 ('struct')

In [None]:
model = ResnetSegmentator(encoder, segmentation_head).to(device)
criterion = nn.CrossEntropyLoss()
# Baseline fine-tunes the whole network (backbone + head).
optimizer = optim.Adam(params=model.parameters(), lr=0.0003)
trainer = Trainer(model=model, optimizer=optimizer, criterion=criterion)

In [None]:
print_parameters(model=segmentation_head)

model initialized with trainable params: 55,265,461 || total params: 55,265,461 || trainable%: 100.000


In [None]:
from torch.utils.data import DataLoader

_baseline_loader_kwargs = dict(num_workers=2)
train_loader = DataLoader(data_train, batch_size=64, shuffle=True, **_baseline_loader_kwargs)
test_loader = DataLoader(data_test, batch_size=128, shuffle=False, **_baseline_loader_kwargs)


In [None]:
from google.colab import drive
drive.mount(mountpoint='/content/drive')  # checkpoints are written to Google Drive

Mounted at /content/drive


In [None]:
ckpt_path = '/content/drive/MyDrive/gpu_shared/ssl/ssl_proj'  # checkpoint directory on Drive

In [None]:
trainer.train_loop(
    100, train_loader, test_loader,
    ckpt_path=ckpt_path,
    run_name='resnet_segmentator',
    resume=10,           # restore resnet_segmentator_ckpt10.pt and continue from epoch 10
    run_id='d6bjxsrc',   # same wandb run
)

  checkpoint = torch.load(filename)
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


epoch:   0%|          | 0/90 [00:00<?, ?it/s]

training:   0%|          | 0/58 [00:00<?, ?it/s]

evaluating:   0%|          | 0/29 [00:00<?, ?it/s]

evaluating:   0%|          | 0/29 [00:00<?, ?it/s]