Cue-conflict dataset generation, using the AdaIN model implementation from https://github.com/irasin/Pytorch_AdaIN/blob/master/model.py

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19, VGG19_Weights
import random
import os
import torchvision
from PIL import Image

In [17]:
def calc_mean_std(features):
    """
    Compute the per-sample, per-channel mean and standard deviation.

    Every dimension after the first two is flattened and reduced over.

    :param features: tensor of shape [batch_size, c, h, w]
    :return: (features_mean, features_std), each of shape [batch_size, c, 1, 1];
        the std is torch's default (unbiased) estimate plus 1e-6.
    """
    n, ch = features.size()[:2]
    flat = features.reshape(n, ch, -1)
    stat_shape = (n, ch, 1, 1)
    features_mean = flat.mean(dim=2).reshape(stat_shape)
    # The epsilon keeps later divisions by the std finite for constant channels.
    features_std = flat.std(dim=2).reshape(stat_shape) + 1e-6
    return features_mean, features_std


def adain(content_features, style_features):
    """
    Adaptive Instance Normalization.

    Re-scales the content features so that every channel takes on the mean
    and standard deviation of the matching channel of the style features.

    :param content_features: shape -> [batch_size, c, h, w]
    :param style_features: shape -> [batch_size, c, h, w]
    :return: normalized_features shape -> [batch_size, c, h, w]
    """
    c_mean, c_std = calc_mean_std(content_features)
    s_mean, s_std = calc_mean_std(style_features)
    centered = content_features - c_mean
    # Evaluated as s_std * (x - mu_c) / sigma_c + mu_s, in that order.
    return s_std * centered / c_std + s_mean


class VGGEncoder(nn.Module):
    """Frozen VGG-19 feature extractor split into four consecutive stages.

    The stage attributes ``slice1``..``slice4`` define the state-dict keys, so
    keep their names stable for checkpoint loading.
    """

    def __init__(self):
        super().__init__()
        layers = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
        # In torchvision's VGG-19 indexing, these cuts end right after
        # relu1_1, relu2_1, relu3_1 and relu4_1 respectively.
        self.slice1 = layers[:2]
        self.slice2 = layers[2:7]
        self.slice3 = layers[7:12]
        self.slice4 = layers[12:21]
        # The encoder is never trained: disable gradients on all its weights.
        self.requires_grad_(False)

    def forward(self, images, output_last_feature=False):
        """Run the images through all four stages.

        :param images: input batch fed to ``slice1``
        :param output_last_feature: if True, return only the final stage output
        :return: the final stage output, or a tuple of all four stage outputs
        """
        stage_outputs = []
        h = images
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4):
            h = stage(h)
            stage_outputs.append(h)
        return stage_outputs[-1] if output_last_feature else tuple(stage_outputs)


class RC(nn.Module):
    """Reflection padding followed by a Conv2d, with an optional ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, pad_size=1, activated=True):
        super().__init__()
        # Same padding on left, right, top and bottom.
        self.pad = nn.ReflectionPad2d((pad_size,) * 4)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.activated = activated

    def forward(self, x):
        y = self.conv(self.pad(x))
        return F.relu(y) if self.activated else y


class Decoder(nn.Module):
    """Maps 512-channel encoder features back to a 3-channel image.

    Three 2x nearest-neighbour upsamplings give an overall 8x spatial
    enlargement. The ``rc1``..``rc9`` attribute names define the state-dict
    keys and must stay stable.
    """

    def __init__(self):
        super().__init__()
        self.rc1 = RC(512, 256, kernel_size=3, pad_size=1)
        self.rc2 = RC(256, 256, kernel_size=3, pad_size=1)
        self.rc3 = RC(256, 256, kernel_size=3, pad_size=1)
        self.rc4 = RC(256, 256, kernel_size=3, pad_size=1)
        self.rc5 = RC(256, 128, kernel_size=3, pad_size=1)
        self.rc6 = RC(128, 128, kernel_size=3, pad_size=1)
        self.rc7 = RC(128, 64, kernel_size=3, pad_size=1)
        self.rc8 = RC(64, 64, kernel_size=3, pad_size=1)
        # Final layer outputs raw (unactivated) image values.
        self.rc9 = RC(64, 3, kernel_size=3, pad_size=1, activated=False)

    def forward(self, features):
        h = F.interpolate(self.rc1(features), scale_factor=2)
        for layer in (self.rc2, self.rc3, self.rc4, self.rc5):
            h = layer(h)
        h = F.interpolate(h, scale_factor=2)
        h = self.rc7(self.rc6(h))
        h = F.interpolate(h, scale_factor=2)
        return self.rc9(self.rc8(h))


class Model(nn.Module):
    """AdaIN style-transfer network: frozen VGG encoder plus a decoder.

    The submodule names ``vgg_encoder`` and ``decoder`` define the state-dict
    keys used by checkpoints, so they must stay stable.
    """

    def __init__(self):
        super().__init__()
        self.vgg_encoder = VGGEncoder()
        self.decoder = Decoder()

    def _stylized_target(self, content_images, style_images, alpha):
        """Return the AdaIN target features, blended with the content features.

        alpha=1 gives the pure AdaIN output; alpha=0 gives the content features.
        """
        content_features = self.vgg_encoder(content_images, output_last_feature=True)
        style_features = self.vgg_encoder(style_images, output_last_feature=True)
        t = adain(content_features, style_features)
        return alpha * t + (1 - alpha) * content_features

    def generate(self, content_images, style_images, alpha=1.0):
        """Decode the (alpha-blended) AdaIN features into stylised images."""
        t = self._stylized_target(content_images, style_images, alpha)
        return self.decoder(t)

    # Static methods: they use no instance state, and forward() calls them
    # through ``self`` (without @staticmethod that call passes ``self`` as the
    # first argument and raises TypeError).
    @staticmethod
    def calc_content_loss(out_features, t):
        """MSE between the re-encoded output features and the AdaIN target."""
        return F.mse_loss(out_features, t)

    @staticmethod
    def calc_style_loss(content_middle_features, style_middle_features):
        """Sum of MSEs between per-channel means and stds of paired features.

        :param content_middle_features: intermediate features of the generated
            images (forward() passes the output's features here)
        :param style_middle_features: intermediate features of the style images
        """
        loss = 0
        for c, s in zip(content_middle_features, style_middle_features):
            c_mean, c_std = calc_mean_std(c)
            s_mean, s_std = calc_mean_std(s)
            loss += F.mse_loss(c_mean, s_mean) + F.mse_loss(c_std, s_std)
        return loss

    def forward(self, content_images, style_images, alpha=1.0, lam=10):
        """Return the training loss: content loss + lam * style loss."""
        t = self._stylized_target(content_images, style_images, alpha)
        out = self.decoder(t)

        output_middle_features = self.vgg_encoder(out, output_last_feature=False)
        # The last intermediate feature is exactly what output_last_feature=True
        # returns, so reuse it rather than running the encoder over `out` twice.
        output_features = output_middle_features[-1]
        style_middle_features = self.vgg_encoder(style_images, output_last_feature=False)

        loss_c = self.calc_content_loss(output_features, t)
        loss_s = self.calc_style_loss(output_middle_features, style_middle_features)
        return loss_c + lam * loss_s

In [None]:
# Prefer a CUDA GPU, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Transforms: fixed 224x224 resize, tensor conversion, then normalisation with
# the ImageNet channel mean/std.
cue_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
])

to_pil = torchvision.transforms.ToPILImage()

# STL-10 test split. STL10 selects its split with `split=`; it has no `train`
# argument (passing train=False raises TypeError). download=False means the
# data must already exist under ./data/stl.
cue_test_dataset = torchvision.datasets.STL10(root="./data/stl", split="test", download=False, transform=cue_transform)
cue_test_loader = torch.utils.data.DataLoader(cue_test_dataset, batch_size=16, shuffle=False)

# Custom style dataset
class StyleDataset(torch.utils.data.Dataset):
    """Dataset of the style images found directly inside one folder.

    Each item is the (optionally transformed) RGB image; there are no labels.
    """

    def __init__(self, style_folder, transform=None, extensions=(".png", ".jpg", ".jpeg")):
        """
        :param style_folder: directory scanned (non-recursively) for images
        :param transform: optional callable applied to each PIL image
        :param extensions: accepted file suffixes, matched case-insensitively
        """
        suffixes = tuple(ext.lower() for ext in extensions)
        # os.listdir order is arbitrary; sorting makes the index -> file mapping
        # (and so any seeded random choice over this dataset) reproducible.
        self.files = [
            os.path.join(style_folder, name)
            for name in sorted(os.listdir(style_folder))
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(style_folder, name))
        ]
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # The context manager closes the file; convert() returns a new,
        # fully loaded image that no longer depends on it.
        with Image.open(self.files[idx]) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

style_dataset = StyleDataset("./data/styles", transform=cue_transform)
if len(style_dataset) == 0:
    raise RuntimeError("No style images found in ./data/styles")

# Load pretrained model.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (on torch >= 1.13, weights_only=True is safer for a plain state dict).
model = Model()
state_dict = torch.load("model_state.pth", map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Output folder
output_root = "cue_conflict_dataset"
os.makedirs(output_root, exist_ok=True)

# AdaIN blend factor for Model.generate: 1.0 = fully stylised features,
# 0.0 = content features unchanged.
ALPHA = 0.1
# Seeded RNG so the per-batch style choice (and hence the dataset) is reproducible.
style_rng = random.Random(0)

# cue_transform normalises inputs with ImageNet mean/std, and the reference
# Pytorch_AdaIN code de-normalises decoder outputs with the same statistics
# before clamping. Clamping the raw normalised output to [0, 1] distorts colours
# and contrast, so undo the normalisation first.
# NOTE(review): assumes model_state.pth was trained on identically normalised
# inputs — confirm.
normalize_step = next(
    step for step in cue_transform.transforms
    if isinstance(step, torchvision.transforms.Normalize)
)
norm_mean = torch.tensor(normalize_step.mean, device=device).view(-1, 1, 1)
norm_std = torch.tensor(normalize_step.std, device=device).view(-1, 1, 1)

# Generate cue-conflict images
with torch.no_grad():
    for i, (content_batch, labels) in enumerate(cue_test_loader):
        content_batch = content_batch.to(device)

        # pick a random style for this batch and repeat it for every content image
        style_idx = style_rng.randrange(len(style_dataset))
        style_img = style_dataset[style_idx].unsqueeze(0).to(device)
        style_batch = style_img.expand(content_batch.size(0), -1, -1, -1)

        # generate cue-conflict outputs and map them back to [0, 1] pixel space
        outputs = model.generate(content_batch, style_batch, alpha=ALPHA)
        outputs = (outputs * norm_std + norm_mean).clamp(0, 1).cpu()

        # save each output in a subfolder named after its label
        for j, out_tensor in enumerate(outputs):
            label = labels[j].item()
            label_folder = os.path.join(output_root, str(label))
            os.makedirs(label_folder, exist_ok=True)

            filename = os.path.join(label_folder, f"content_{i*cue_test_loader.batch_size + j}.png")
            to_pil(out_tensor).save(filename)

print("Cue-conflict dataset created at ./cue_conflict_dataset")

Cue-conflict dataset created at ./cue_conflict_dataset
