In [1]:
import os

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from tqdm import tqdm

import datasets.acdc
import train
from datasets import Cityscapes
from metrics import StreamSegMetrics
from utils import ext_transforms as et

In [2]:
# Pick the compute device. `torch.cuda.is_available` must be *called*: the bare
# function object is always truthy, which would select 'cuda:0' on CPU-only hosts.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # device index may need to be adjusted

In [3]:
# Number of segmentation classes the model head predicts.
NUM_CLASSES = 19

In [4]:
# Build DeepLabV3-ResNet50 with NUM_CLASSES outputs and restore the checkpoint.
model = torchvision.models.segmentation.deeplabv3_resnet50(num_classes=NUM_CLASSES)
# NOTE(review): DataParallel expects the wrapped module's parameters on
# device_ids[0]; keep `device` and `device_ids` consistent for your machine.
model = torch.nn.DataParallel(model, device_ids=[2,1])  # device_ids need to be adjusted
model.to(device)
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads on machines that lack the GPU index it was saved from.
# torch.load unpickles the file: only load checkpoints from a trusted source.
checkpoint = torch.load('MODEL_WEIGHTS', map_location=device)
# The model is already wrapped in DataParallel when the state dict is loaded,
# so the checkpoint keys have to match the wrapped model's key names.
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [5]:
# Tensor conversion only, no normalization: used for the adaptation images,
# which are augmented and normalized later.
to_tensor = et.ExtCompose([
    et.ExtToTensor(),
])

# Deterministic evaluation pipeline: tensor conversion + normalization.
val_transform = et.ExtCompose([
    et.ExtToTensor(),
    et.ExtNormalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Training-style augmentation pipeline: random scale, random 512x1024 crop
# (padding if needed), random horizontal flip, then tensor + normalization.
train_transform = et.ExtCompose([
    et.ExtRandomScale(scale_range=(0.5, 2.0)),
    et.ExtRandomCrop(size=(512, 1024), pad_if_needed=True),
    et.ExtRandomHorizontalFlip(),
    et.ExtToTensor(),
    et.ExtNormalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [7]:
# ACDC subset (weather condition) used for both validation and adaptation.
WEATHER = 'fog'

In [8]:
# Validation split of the selected ACDC subset, normalized with val_transform
# and served one image per batch.
BATCH_SIZE = 1
validation_set = datasets.acdc.Acdc(
    r"PATH TO ACDC SET",
    'val',
    'semantic',
    transform=val_transform,
    weather=WEATHER,
)
val_loader = data.DataLoader(
    validation_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    drop_last=True,
)

In [9]:
# Train split of the selected ACDC subset, converted to tensors only (to_tensor);
# these images feed the adaptation loop, which applies its own augmentation.
BATCH_SIZE = 1
tensor_set = datasets.acdc.Acdc(
    r"PATH TO ACDC SET",
    'train',
    'semantic',
    transform=to_tensor,
    weather=WEATHER,
)
tensor_loader = data.DataLoader(
    tensor_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    drop_last=True,
)

In [9]:
# Metric accumulator shared by every train.validate(...) call in this notebook.
# NOTE(review): NUM_CLASSES is re-declared here with the same value it was given
# before the model was built; keep the two assignments in sync (or drop this one).
NUM_CLASSES = 19
metrics = StreamSegMetrics(n_classes=NUM_CLASSES)
# Start from an empty accumulator.
metrics.reset()

In [68]:
# Evaluate the loaded (not yet adapted) model on the validation loader;
# the results are accumulated in `metrics`.
train.validate(model, device, val_loader, metrics)

100%|██████████| 100/100 [00:34<00:00,  2.90it/s]


In [10]:
# Augmentation applied to a single image tensor during adaptation:
# tensor -> PIL image -> random 512x1024 crop -> random horizontal flip
# -> tensor -> normalization.
tr_transform_adapt = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomCrop((512, 1024)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [11]:
# Number of adaptation steps (one image is drawn per step).
NUM_SAMPLES = 80

In [12]:
# Number of augmented views generated per adaptation image.
NR_AUGMENTS = 8

In [13]:
def get_adaption_inputs(img, tr_transform_adapt, device, nr_augments=None):
    """Build a batch of independently augmented views of one image.

    Args:
        img: Image tensor; a leading dimension of size 1 (e.g. the batch axis
            of a batch_size=1 loader) is dropped with ``squeeze(0)``.
        tr_transform_adapt: Callable applied once per view. Every call must
            return a tensor of the same shape so the views can be stacked.
        device: Unused; kept so existing callers keep working. The returned
            batch is not moved to any device by this function.
        nr_augments: Number of views to generate. Defaults to the
            module-level ``NR_AUGMENTS`` when omitted.

    Returns:
        A tensor stacking the ``nr_augments`` transformed views along a new
        leading dimension.

    Raises:
        ValueError: If ``nr_augments`` is smaller than 1.
    """
    if nr_augments is None:
        nr_augments = NR_AUGMENTS
    if nr_augments < 1:
        raise ValueError(f"nr_augments must be >= 1, got {nr_augments}")

    img = img.squeeze(0)
    views = [tr_transform_adapt(img) for _ in range(nr_augments)]
    return torch.stack(views)

In [None]:
# --- BatchNorm statistics adaptation --------------------------------------
# For NUM_SAMPLES images from tensor_loader, run one forward pass over
# NR_AUGMENTS augmented views with the BatchNorm layers in train mode, so that
# only their running statistics change (everything runs under torch.no_grad()).
# The BatchNorm momentum shrinks by DECAY_FACTOR every step and is offset by
# MIN_MOMENTUM_CONSTANT so it never reaches zero.
mom_pre = 0.1
DECAY_FACTOR = 0.94
MIN_MOMENTUM_CONSTANT = 0.005

model.to(device)
model.eval()

# results[0] is the Mean IoU before adaptation; one entry is appended per step.
results = []
train.validate(model, device, val_loader, metrics)
results.append(metrics.get_results()["Mean IoU"])

# The set of BatchNorm layers does not change during the loop; collect it once.
bn_layers = [
    m for m in model.modules()
    if isinstance(m, torch.nn.modules.batchnorm._BatchNorm)
]

# Create the loader iterator once. Calling iter(tensor_loader) on every step
# would respawn all worker processes each time and draw images with replacement.
tensor_iter = iter(tensor_loader)

with torch.no_grad():
    for step in tqdm(range(NUM_SAMPLES)):
        try:
            image, _ = next(tensor_iter)
        except StopIteration:
            # Loader exhausted before NUM_SAMPLES steps: start a new pass.
            tensor_iter = iter(tensor_loader)
            image, _ = next(tensor_iter)

        # Decay the momentum and switch only the BatchNorm layers to train
        # mode so the forward pass below updates their running statistics.
        mom_new = mom_pre * DECAY_FACTOR
        for bn in bn_layers:
            bn.momentum = mom_new + MIN_MOMENTUM_CONSTANT
            bn.train()
        mom_pre = mom_new

        # augment
        inputs = get_adaption_inputs(image, tr_transform_adapt, device)

        # forward pass (output discarded; only the BatchNorm buffers matter)
        _ = model(inputs)

        # Put the BatchNorm layers back into eval mode before validating, so
        # evaluation uses the adapted running statistics and cannot update
        # them with validation batches. NOTE(review): train.validate is not
        # visible here; if it already calls model.eval() this is a no-op.
        model.eval()

        # statistics
        train.validate(model, device, val_loader, metrics)
        mean_iou = metrics.get_results()["Mean IoU"]
        results.append(mean_iou)
        print(mean_iou)

Saving the adapted model


In [15]:
# Persist the adapted weights. The file name follows WEATHER so that adapting to
# another ACDC subset does not overwrite (or mislabel) this checkpoint.
model_path = f'adapted_models/resnet_50_CS_{WEATHER}'
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
# The state dict is taken from the DataParallel-wrapped model and stored under
# the same 'model' key that the loading cell reads.
torch.save({'model': model.state_dict()}, model_path)