In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1" # is need to train on 'hachiko'

import math
import time
import pandas as pd

import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.distributed import all_reduce, ReduceOp

import numpy as np
from datasets import Dataset
# from torch.utils.data import Dataset
from PIL import Image, ImageFilter, ImageOps

# from funcs import is_main, to_devices, print_msg

In [2]:
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from functools import partial
from vits import archs
from ssit import SSiT

class SSITConfig(PretrainedConfig):
    """Configuration holding the hyper-parameters read by ``SSITSegmentation``.

    Every hyper-parameter is a keyword argument whose default equals the value
    that used to be hard-coded, so ``SSITConfig()`` is unchanged. Values found
    in a saved ``config.json`` (or passed by the caller) are now honoured
    instead of being silently discarded.

    Args:
        temperature: Forwarded to ``SSiT`` as ``T``.
        pool_mode: Forwarded to ``SSiT`` as ``pool_mode``.
        saliency_threshold: Forwarded to ``SSiT`` as ``saliency_threshold``.
        arch: Key used to look up the encoder factory in ``vits.archs``.
        pretrained: Forwarded to the encoder factory as ``pretrained``.
        input_size: Forwarded to the encoder factory as ``img_size``.
        mask_ratio: Forwarded to the encoder factory as ``mask_ratio``.
        epochs: Presumably the total number of training epochs -- confirm
            against the training script.
        moco_m: Base value of the momentum cosine schedule.
        ss: Base weight of the ``ss_loss`` term in the total loss.
        cl: Weight of the ``cl_loss`` term in the total loss.
        **kwargs: Any remaining keys are forwarded to ``PretrainedConfig``.
    """

    model_type = "ssit"

    def __init__(
        self,
        temperature=0.2,
        pool_mode='max',
        saliency_threshold=0.5,
        arch='ViT-S-p16',
        pretrained=True,
        input_size=224,
        mask_ratio=0.25,
        epochs=300,
        moco_m=0.99,
        ss=10,
        cl=1,
        **kwargs
    ):
        # Initialise the base class first. NOTE(review): depending on the
        # transformers version, PretrainedConfig.__init__ may assign its own
        # default to same-named attributes (e.g. ``temperature``); assigning
        # our values afterwards guarantees they are the ones that stick.
        super().__init__(**kwargs)

        self.temperature = temperature
        self.pool_mode = pool_mode
        self.saliency_threshold = saliency_threshold
        self.arch = arch
        self.pretrained = pretrained
        self.input_size = input_size
        self.mask_ratio = mask_ratio
        self.epochs = epochs
        self.moco_m = moco_m
        self.ss = ss
        self.cl = cl
        
class SSITSegmentation(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the project-local ``SSiT`` model.

    The wrapper builds the encoder factory and the ``SSiT`` module from an
    ``SSITConfig``, keeps track of the current training epoch, and combines the
    two losses returned by ``SSiT`` into a single weighted loss.
    """

    config_class = SSITConfig

    def __init__(self, config):
        """Build the encoder factory and the wrapped ``SSiT`` model from ``config``."""
        super().__init__(config)

        self.moco_m = config.moco_m
        self.ss = config.ss
        self.cl = config.cl
        self.ss_decay = True
        self.epoch = 0
        # Default the schedule length to the configured number of epochs so that
        # forward() works without explicit ``epoch``/``epochs`` (a default of 0
        # made the cosine schedules divide by zero).
        self.epochs = getattr(config, "epochs", 0)

        # Encoder factory with the architecture options pre-bound.
        encoder = partial(
            archs[config.arch],
            pretrained=config.pretrained,
            img_size=config.input_size,
            mask_ratio=config.mask_ratio,
        )

        self.model = SSiT(
            encoder,
            dim=256,
            mlp_dim=4096,
            T=config.temperature,
            pool_mode=config.pool_mode,
            saliency_threshold=config.saliency_threshold,
        )

    def adjust_moco_momentum(self, epoch, epochs):
        """Return the momentum for ``epoch`` on a half-cosine schedule.

        The value starts at ``self.moco_m`` for ``epoch == 0`` and reaches 1.0
        for ``epoch == epochs``. ``epochs`` must be non-zero.
        """
        m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / epochs)) * (1. - self.moco_m)
        return m

    def adjust_lambda_ss(self, epoch, epochs):
        """Return the ``ss_loss`` weight for ``epoch`` on a half-cosine schedule.

        The value starts at ``self.ss`` for ``epoch == 0`` and decays to 0 for
        ``epoch == epochs``. ``epochs`` must be non-zero.
        """
        ss = self.ss * 0.5 * (1. + math.cos(math.pi * epoch / epochs))
        return ss

    def save_weights(self, save_path):
        """Save the wrapped model's ``state_dict`` to ``<save_path>/checkpoint.pt``.

        The directory is created if it does not exist yet.
        """
        os.makedirs(save_path, exist_ok=True)
        checkpoint = {
            'state_dict': self.model.state_dict()
        }

        torch.save(checkpoint, os.path.join(save_path, 'checkpoint.pt'))
        print('Saved checkpoint to {}'.format(save_path))

    def forward(self, X1, X2, M1, M2, epoch=None, epochs=None, return_loss=True):
        """Run the wrapped model and return the weighted total loss.

        Args:
            X1, X2, M1, M2: Passed unchanged to ``SSiT`` (presumably two image
                views and their saliency masks -- confirm against ``SSiT``).
            epoch: Current epoch. When ``None``, the last stored ``epoch`` and
                ``epochs`` are reused; otherwise the value is stored.
            epochs: Total number of epochs. When ``None``, the stored value is
                reused; otherwise the value is stored.
            return_loss: Accepted for interface compatibility; currently unused.

        Returns:
            dict with keys ``"loss"`` (``cl * cl_loss + ss * ss_loss``),
            ``"cl_loss"`` and ``"ss_loss"``.
        """
        if epoch is None:
            epoch = self.epoch
            epochs = self.epochs
        else:
            self.epoch = epoch
            if epochs is None:
                # Only the epoch was given: keep the known schedule length
                # instead of overwriting it with None.
                epochs = self.epochs
            else:
                self.epochs = epochs

        moco_m = self.adjust_moco_momentum(epoch, epochs)
        ss = self.adjust_lambda_ss(epoch, epochs) if self.ss_decay else self.ss

        cl_loss, ss_loss = self.model(X1, X2, M1, M2, moco_m)
        loss = self.cl * cl_loss + ss * ss_loss

        return {"loss": loss, "cl_loss": cl_loss, "ss_loss": ss_loss}
        

In [3]:
# Load a previously saved SSITSegmentation from a local directory;
# from_pretrained is inherited from the PreTrainedModel base class.
model = SSITSegmentation.from_pretrained("saved_models/SSIT_unlabled_bs64_100ep")

In [4]:
# Write the wrapped SSiT state_dict to checkpoint.pt inside the same directory.
model.save_weights("saved_models/SSIT_unlabled_bs64_100ep")

Saved checkpoint to saved_models/SSIT_unlabled_bs64_100ep
