In [4]:
from argparse import ArgumentParser

def _build_parser():
    """Assemble the command-line options used by this notebook.

    Options are registered in the same order as before, so the parsed
    Namespace lists its fields in the same order. The automatic ``-h/--help``
    option is disabled (``add_help=False``).
    """
    cli = ArgumentParser(add_help=False)
    add = cli.add_argument

    # --- data location and optimisation schedule ---
    add("--data_path", default="../datasets", type=str, help="path where dataset is stored")
    add("--accumulate_grad_batches", default=2, type=int, help="")
    add("--max_epochs", default=240, type=int, help="max epochs")
    add("--dataset", default="ade20k", help="dataset to train on")
    add("--batch_size", default=1, type=int, help="size of the batches")
    add("--base_lr", default=0.004, type=float, help="learning rate")
    add("--momentum", default=0.9, type=float, help="SGD momentum")
    add("--weight_decay", default=1e-4, type=float, help="weight_decay")

    # --- optional extra losses ---
    add("--aux", default=False, action="store_true", help="Auxilary Loss")
    add("--aux-weight", default=0.2, type=float, help="Auxilary loss weight (default: 0.2)")
    add("--se-loss", default=False, action="store_true", help="Semantic Encoding Loss SE-loss")
    add("--se-weight", default=0.2, type=float, help="SE-loss weight (default: 0.2)")

    # --- data handling ---
    add("--midasproto", default=False, action="store_true", help="midasprotocol")
    add("--ignore_index", default=-1, type=int, help="numeric value of ignore label in gt")
    add("--augment", default=False, action="store_true", help="Use extended augmentations")

    # --- model architecture ---
    add("--backbone", default="clip_vitl16_384", type=str, help="backbone network")
    add("--num_features", default=256, type=int, help="number of featurs that go from encoder to decoder")
    add("--dropout", default=0.1, type=float, help="dropout rate")
    add("--finetune_weights", type=str, help="load weights to finetune from")
    # The "--no-*" switches keep argparse's derived destinations
    # (no_scaleinv / no_batchnorm); no_scaleinv therefore defaults to True
    # and becomes False when the flag is given.
    add("--no-scaleinv", default=True, action="store_false", help="turn off scaleinv layers")
    add("--no-batchnorm", default=False, action="store_true", help="turn off batchnorm")
    add("--widehead", default=False, action="store_true", help="wider output head")
    add("--widehead_hr", default=False, action="store_true", help="wider output head")
    add("--arch_option", default=0, type=int, help="which kind of architecture to be used")
    add("--block_depth", default=0, type=int, help="how many blocks should be used")
    add("--activation", default="lrelu", choices=['lrelu', 'tanh'], help="use which activation to activate the block")
    return cli


parser = _build_parser()

# Parse an explicit empty argument list so every option takes its default
# instead of being read from sys.argv.
args = parser.parse_args([])
args.exp_name = "lseg_ade20k_l16"

args

Namespace(data_path='../datasets', accumulate_grad_batches=2, max_epochs=240, dataset='ade20k', batch_size=1, base_lr=0.004, momentum=0.9, weight_decay=0.0001, aux=False, aux_weight=0.2, se_loss=False, se_weight=0.2, midasproto=False, ignore_index=-1, augment=False, backbone='clip_vitl16_384', num_features=256, dropout=0.1, finetune_weights=None, no_scaleinv=True, no_batchnorm=False, widehead=False, widehead_hr=False, arch_option=0, block_depth=0, activation='lrelu', exp_name='lseg_ade20k_l16')

In [None]:
# do_training(args, LSegModule)
from modules.lseg_module import LSegModule
import pytorch_lightning as pl

from utils import make_checkpoint_callbacks, get_wandb_logger

# Checkpoint file the LSeg module is restored from.
checkpoint = "./checkpoints/demo_e200.ckpt"

# NOTE(review): the parser in this notebook defines `base_lr`, not `lr`;
# confirm that LSegModule actually reads an `lr` keyword.
args.lr = 0.00001

# Restore the module first: at this point `args` holds only the parsed
# options plus exp_name/lr, and all of them are forwarded as keyword arguments.
lseg = LSegModule.load_from_checkpoint(checkpoint, **vars(args))

# Trainer-level settings, added to the namespace only after the module has
# been loaded so they are not forwarded to load_from_checkpoint.
# NOTE(review): confirm the installed pytorch_lightning version supports the
# "ddp" accelerator when launched from a notebook kernel.
vars(args).update(
    gpus=-1,
    accelerator="ddp",
    benchmark=True,
    version=0,
    sync_batchnorm=True,
)

# Both the test-tube logger and the checkpoint callbacks are keyed by the
# experiment name and version.
experiment, run_version = args.exp_name, args.version
ttlogger = pl.loggers.TestTubeLogger("checkpoints", name=experiment, version=run_version)
args.callbacks = make_checkpoint_callbacks(experiment, run_version)

# The wandb logger is built from the namespace after the callbacks are attached.
wblogger = get_wandb_logger(args)
args.logger = [wblogger, ttlogger]

trainer = pl.Trainer.from_argparse_args(args)

# only train on a subset of data during dev
trainer.limit_train_batches = trainer.limit_val_batches = 0.001

In [None]:
import open_clip
import torch

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def load_biomed_clip(device, model_name='hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'):
    """Create an open_clip model together with its tokenizer and transforms.

    The same ``model_name`` is used for both the model and the tokenizer, so
    the two cannot drift apart.

    Args:
        device: Target passed to ``model.to(...)`` (e.g. ``'cuda'`` or ``'cpu'``).
        model_name: Identifier handed to ``open_clip.create_model_and_transforms``
            and ``open_clip.get_tokenizer``. Defaults to the BiomedCLIP
            checkpoint id; the ``hf-hub:`` prefix presumably makes open_clip
            fetch it from the Hugging Face Hub -- confirm network/cache needs.

    Returns:
        Tuple ``(model, tokenizer, preprocess_train, preprocess_val)``.
    """
    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)
    tokenizer = open_clip.get_tokenizer(model_name)
    model.to(device)
    return model, tokenizer, preprocess_train, preprocess_val

# define the LightningModule
class Model(pl.LightningModule):
    """LightningModule that forwards all work to a wrapped LSeg module.

    Data loading, the training step and optimizer configuration are delegated
    unchanged to ``self.lseg``. A BiomedCLIP adaptation branch is only
    sketched in TODO comments and is not active.
    """

    def __init__(self, lseg):
        """Keep a reference to the wrapped module.

        Args:
            lseg: Object providing ``train_dataloader``, ``training_step`` and
                ``configure_optimizers``.
        """
        super().__init__()
        self.lseg = lseg
        # TODO (inactive sketch): also keep the BiomedCLIP pieces here, e.g.
        #   self.biomed_clip_model, self.biomed_clip_tokenizer, \
        #       self.preprocess_train, self.preprocess_val = load_biomed_clip(device)
        #   image_features, text_features, logit_scale = biomed_clip_model(images, texts)

    def train_dataloader(self):
        """Return the training dataloader supplied by the wrapped module."""
        # TODO (inactive sketch): return one loader per source instead, e.g.
        #   {"lseg": self.lseg.train_dataloader(), "biomed_clip": <biomed clip loader>}
        wrapped = self.lseg
        return wrapped.train_dataloader()

    def training_step(self, batch, batch_idx):
        """Run the wrapped module's training step on ``batch`` and return its result."""
        # TODO (inactive sketch): with a dict of loaders, split the batch into
        #   batch["lseg"] and batch["biomed_clip"], compute the segmentation
        #   loss from the former and an adaptation loss from the latter.
        wrapped = self.lseg
        return wrapped.training_step(batch, batch_idx)

    def configure_optimizers(self):
        """Return whatever optimizer configuration the wrapped module defines."""
        wrapped = self.lseg
        return wrapped.configure_optimizers()

    # TODO (inactive sketch): def adapt_loss(self, batch_biomed_clip):
    #     - get the image and text features from biomed clip
    #     - get the image and text features from lseg
    #     - compute the loss between the two

In [None]:
# Wrap the restored LSeg module in the Lightning wrapper class defined above.
model = Model(lseg)

# Start training. No dataloaders are passed here; `Model` defines
# train_dataloader() itself.
trainer.fit(model)