In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader, ConcatDataset
from torchvision.models.resnet import ResNet18_Weights
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.models import resnet18
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE 
from sklearn.cluster import DBSCAN, KMeans
from sklearn.neighbors import NearestNeighbors
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


In [3]:
means, stds = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
transf = transforms.Compose([
                              transforms.CenterCrop(224),  # Crops a central square patch of the image 224 because torchvision's AlexNet needs a 224x224 input!
                              transforms.ToTensor(), # Turn PIL Image to torch.Tensor
                              transforms.Normalize(means,stds) # Normalizes tensor with mean and standard deviation
])

def photo_transform(data):
    transf_data = transf(data)
    transf_data.domain_id = 1
    return transf_data

def art_transform(data):
    transf_data = transf(data)
    transf_data.domain_id = 2
    return transf_data

def cartoon_transform(data):
    transf_data = transf(data)
    transf_data.domain_id = 3
    return transf_data

def sketch_transform(data):
    transf_data = transf(data)
    transf_data.domain_id = 4
    return transf_data

device = 'cuda:1'
num_epoch = 50

In [4]:
from matplotlib.patches import Ellipse

def draw_ellipse(position, covariance, ax=None, **kwargs):
    """Draw an ellipse with a given position and covariance"""
    ax = ax or plt.gca()
    
    # Convert covariance to principal axes
    if covariance.shape == (2, 2):
        U, s, Vt = np.linalg.svd(covariance)
        angle = np.degrees(np.arctan2(U[1, 0], U[0, 0]))
        width, height = 2 * np.sqrt(s)
    else:
        angle = 0
        width, height = 2 * np.sqrt(covariance)
    
    # Draw the ellipse
    for nsig in range(1, 4):
        ax.add_patch(Ellipse(position, nsig * width, nsig * height,
                             angle, **kwargs))

def plot_gmm(gmm, X, label=True, ax=None):
    ax = ax or plt.gca()
    labels = gmm.fit(X).predict(X)
    if label:
        ax.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis', zorder=2)
    else:
        ax.scatter(X[:, 0], X[:, 1], s=40, zorder=2)
    ax.axis('equal')
    
    w_factor = 0.2 / gmm.weights_.max()
    for pos, covar, w in zip(gmm.means_, gmm.covariances_, gmm.weights_):
        draw_ellipse(pos, covar, alpha=w * w_factor)

In [5]:
dir_photo = '../data/pacs_v1.0/photo/'
dir_art = '../data/pacs_v1.0/art_painting/'
dir_cartoon = '../data/pacs_v1.0/cartoon/'
dir_sketch = '../data/pacs_v1.0/sketch/'

photo_dataset = ImageFolder(dir_photo, transform=photo_transform)
art_dataset = ImageFolder(dir_art, transform=art_transform)
cartoon_dataset = ImageFolder(dir_cartoon, transform=cartoon_transform)
sketch_dataset = ImageFolder(dir_sketch, transform=sketch_transform)

In [6]:
print(f"Photo Dataset: {len(photo_dataset)}")
print(f"Art Dataset: {len(art_dataset)}")
print(f"Cartoon Dataset: {len(cartoon_dataset)}")
print(f"Sketch Dataset: {len(sketch_dataset)}")

Photo Dataset: 1670
Art Dataset: 2048
Cartoon Dataset: 2344
Sketch Dataset: 3929


In [8]:
photo_train_dataset, photo_test_dataset = torch.utils.data.random_split(photo_dataset, [1336, 334])
art_train_dataset, art_test_dataset = torch.utils.data.random_split(art_dataset, [1638, 410])
cartoon_train_dataset, cartoon_test_dataset = torch.utils.data.random_split(cartoon_dataset, [1875, 469])
sketch_train_dataset, sketch_test_dataset = torch.utils.data.random_split(sketch_dataset, [3143, 786])

# concated_train_dataset = ConcatDataset([photo_train_dataset, art_train_dataset, cartoon_train_dataset, sketch_train_dataset])
# concated_test_dataset = ConcatDataset([photo_test_dataset, art_test_dataset, cartoon_test_dataset, sketch_test_dataset])
concated_train_dataset = ConcatDataset([art_train_dataset, sketch_train_dataset, photo_train_dataset])
# concated_test_dataset = ConcatDataset([sketch_test_dataset, art_test_dataset, cartoon_test_dataset])
concated_train_domain = torch.vstack((torch.full((len(art_dataset), 1), 1), torch.full((len(sketch_dataset), 1), 2), torch.full((len(photo_dataset), 1), 3)))
train_loader = DataLoader(list(zip(concated_train_dataset, concated_train_domain)), batch_size=32, shuffle=True, num_workers=8)
test_loader = DataLoader(cartoon_dataset, batch_size=32, shuffle=False, num_workers=8)
# test_loader = DataLoader(concated_test_dataset, batch_size=32, shuffle=False, num_workers=4)

## Training phase

In [9]:
class MixStyle(nn.Module):
    """MixStyle.
    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix='random'):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
          mix (str): how to mix.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})'

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix='random'):
        self.mix = mix

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x-mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == 'random':
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == 'crossdomain':
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1) # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(B // 2)]
            perm_a = perm_a[torch.randperm(B // 2)]
            perm = torch.cat([perm_b, perm_a], 0)

        else:
            raise NotImplementedError

        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu*lmda + mu2 * (1-lmda)
        sig_mix = sig*lmda + sig2 * (1-lmda)

        return x_normed*sig_mix + mu_mix

In [None]:
class DistributionUncertainty(nn.Module):
    """
    Distribution Uncertainty Module
        Args:
        p   (float): probabilty of foward distribution uncertainty module, p in [0,1].

    """

    def __init__(self, p=0.5, eps=1e-6):
        super(DistributionUncertainty, self).__init__()
        self.eps = eps
        self.p = p
        self.factor = 1.0

    def _reparameterize(self, mu, std):
        epsilon = torch.randn_like(std) * self.factor
        return mu + epsilon * std

    def sqrtvar(self, x):
        t = (x.var(dim=0, keepdim=True) + self.eps).sqrt()
        t = t.repeat(x.shape[0], 1)
        return t

    def forward(self, x):
        if (not self.training) or (np.random.random()) > self.p:
            return x

        mean = x.mean(dim=[2, 3], keepdim=False)
        std = (x.var(dim=[2, 3], keepdim=False) + self.eps).sqrt()

        sqrtvar_mu = self.sqrtvar(mean)
        sqrtvar_std = self.sqrtvar(std)

        beta = self._reparameterize(mean, sqrtvar_mu)
        gamma = self._reparameterize(std, sqrtvar_std)

        x = (x - mean.reshape(x.shape[0], x.shape[1], 1, 1)) / std.reshape(x.shape[0], x.shape[1], 1, 1)
        x = x * gamma.reshape(x.shape[0], x.shape[1], 1, 1) + beta.reshape(x.shape[0], x.shape[1], 1, 1)

        return x

In [10]:
class ConstantStyle(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.mean = []
        self.std = []
        self.eps = eps
        self.const_mean = None
        self.const_std = None
        self.domain_list = []
    
    def clear_memory(self):
        self.mean = []
        self.std = []
        self.domain_list = []
        
    def get_style(self, x):
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        var = var.sqrt()
        mu, var = mu.detach().squeeze(), var.detach().squeeze()
        
        return mu, var
    
    def store_style(self, x, domains):
        mu, var = self.get_style(x)
        self.mean.extend(mu)
        self.std.extend(var)
        self.domain_list.extend([i.item() for i in domains])
    
    def clustering(self, round):
        mean = torch.vstack(self.mean)
        std = torch.vstack(self.std)
        tsne = TSNE(n_components=1, random_state=42)
        transformed_mean = tsne.fit_transform(mean.detach().cpu().numpy())

        tsne2 = TSNE(n_components=1, random_state=42)
        transformed_std = tsne2.fit_transform(std.detach().cpu().numpy())
        plt.cla()
        plt.clf()
        plt.scatter(transformed_mean[:, 0], transformed_std[:, 0])
        plt.savefig(f'mean_std_round{round}.png')
        
        data = torch.cat((mean, std), dim=1).detach().cpu().numpy()
        # neigh = NearestNeighbors(n_neighbors=2)
        # nbrs = neigh.fit(data)
        # distances, indices = nbrs.kneighbors(data)
        # distances = np.sort(distances, axis=0)
        # distances = distances[:,1]
        # plt.figure(figsize=(20,10))
        # plt.plot(distances)
        dbscan = DBSCAN(eps=5, min_samples=50)
        # dbscan = KMeans(n_clusters=3, n_init=50)
        dbscan.fit(data)
        
        labels = dbscan.labels_
        n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
        # print(f'Total cluster: {n_clusters}')
        
        sample_each_label = [len(labels[labels == i]) for i in range(n_clusters)]
        largest_cluster = np.argmax(sample_each_label)
        cluster_mean = mean[labels == largest_cluster]
        cluster_std = std[labels == largest_cluster]
        self.const_mean = torch.mean(cluster_mean, axis=0)
        self.const_std = torch.mean(cluster_std, axis=0)
    
    def cal_mean_std(self, id):
        domain_list = np.array(self.domain_list)
        idx_val = np.where(domain_list == id)[0]
        cluster_mean = [self.mean[i] for i in idx_val]
        cluster_std = [self.std[i] for i in idx_val]
        cluster_mean = torch.stack(cluster_mean)
        cluster_std = torch.stack(cluster_std)
        
        self.const_mean = torch.mean(cluster_mean, axis=0)
        self.const_std = torch.mean(cluster_std, axis=0)
            
    
    def forward(self, x, test=False):
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x-mu) / sig
        const_mean = torch.reshape(self.const_mean, (1, self.const_mean.shape[0], 1, 1))
        const_std = torch.reshape(self.const_std, (1, self.const_std.shape[0], 1, 1))
        out = x_normed * const_std + const_mean
        return out
        

In [11]:
class StyleIntergratedModel(nn.Module):
    def __init__(self, num_style=2):
        super().__init__()
        model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.model = model
        self.mixstyle = MixStyle(p=0.8, alpha=0.1)
        self.num_style = num_style
        self.conststyle = [ConstantStyle() for i in range(self.num_style)]
        self.mean = []
        self.std = []
        self.const_mean = None
        self.const_std = None
    
    def forward(self, x, domains, const_style=False, store_style=False, test=False):
        x = self.model.conv1(x)
        # x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        if store_style:
            self.conststyle[0].store_style(x, domains)
        if const_style:
            x = self.conststyle[0](x, test=test)
        x = self.model.layer2(x)
        if store_style:
            self.conststyle[1].store_style(x, domains)
        if const_style:
            x = self.conststyle[1](x, test=test)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.model.fc(x)

        return x


In [12]:
class MixStyleModel(nn.Module):
    def __init__(self, num_style=2):
        super().__init__()
        model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.model = model
        self.mixstyle = MixStyle(p=0.5, alpha=0.1)
        self.num_style = num_style
        self.mean = []
        self.std = []
        self.const_mean = None
        self.const_std = None
    
    def forward(self, x):
        x = self.model.conv1(x)
        # x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        x = self.mixstyle(x)
        x = self.model.layer2(x)
        x = self.mixstyle(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.model.fc(x)

        return x

In [13]:
random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

model = StyleIntergratedModel()
model.model.fc = torch.nn.Linear(model.model.fc.in_features, 7)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

model2 = MixStyleModel()
model2.model.fc = torch.nn.Linear(model2.model.fc.in_features, 7)
criterion2 = nn.CrossEntropyLoss()
optimizer2 = optim.Adam(model2.parameters(), lr=1e-4, weight_decay=1e-5)

In [16]:
model.to(device)
model2.to(device)
stored_label = []
for epoch in range(num_epoch):
    for conststyle in model.conststyle:
        conststyle.clear_memory()
        
    model.train()
    model2.train()
    running_loss = 0.0
    running_loss2 = 0.0
    for inputs, labels in train_loader:
        inputs, labels, domains = inputs[0].to(device), inputs[1].to(device), labels
        optimizer.zero_grad()
        optimizer2.zero_grad()
            
        stored_label.extend(labels.detach().cpu())
        
        
        if epoch == 0:
            outputs = model(inputs, domains, store_style=True)
        else:
            outputs = model(inputs, domains, const_style=True, store_style=True)
        
        outputs2 = model2(inputs)
        
        loss = criterion(outputs, labels)
        loss2 = criterion2(outputs2, labels)
        
        loss.backward()
        loss2.backward()
        optimizer.step()
        optimizer2.step()

        running_loss += loss.item()
        running_loss2 += loss2.item()
    
    if epoch % 10 == 0:
        for conststyle in model.conststyle:
            conststyle.cal_mean_std(2)

    print(f"Epoch {epoch+1}/{num_epoch}, ConstStyle Loss: {running_loss/len(train_loader)} | MixStyle Loss: {running_loss2/len(train_loader)}")

    model.eval()
    model2.eval()
    correct_predictions = 0
    correct_predictions2 = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs, domains, const_style=True, test=True)
            outputs2 = model2(inputs)

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()
            
            _, predicted2 = torch.max(outputs2, 1)
            correct_predictions2 += (predicted2 == labels).sum().item()

    # Calculate test accuracy
    test_accuracy = correct_predictions / total_samples
    test_accuracy2 = correct_predictions2 / total_samples
    print(f"ConstStyle Accuracy: {test_accuracy * 100:.2f}% | MixStyle Accuracy: {test_accuracy2 * 100:.2f}%")

print("Training finished")


Epoch 1/50, ConstStyle Loss: 0.1305159519057876 | MixStyle Loss: 0.21563147540049007
ConstStyle Accuracy: 70.39% | MixStyle Accuracy: 67.24%
Epoch 2/50, ConstStyle Loss: 0.13874749200961864 | MixStyle Loss: 0.11575003642550048
ConstStyle Accuracy: 75.17% | MixStyle Accuracy: 74.49%
Epoch 3/50, ConstStyle Loss: 0.045971750987519044 | MixStyle Loss: 0.06834541510033887
ConstStyle Accuracy: 74.79% | MixStyle Accuracy: 71.93%
Epoch 4/50, ConstStyle Loss: 0.013806903444371224 | MixStyle Loss: 0.0529746399843134
ConstStyle Accuracy: 76.54% | MixStyle Accuracy: 72.99%
Epoch 5/50, ConstStyle Loss: 0.005755505637656218 | MixStyle Loss: 0.04008968979906058
ConstStyle Accuracy: 76.96% | MixStyle Accuracy: 69.71%
Epoch 6/50, ConstStyle Loss: 0.0033775337509117285 | MixStyle Loss: 0.04303525521148307
ConstStyle Accuracy: 77.69% | MixStyle Accuracy: 70.61%
Epoch 7/50, ConstStyle Loss: 0.003541156085248076 | MixStyle Loss: 0.035517664236977
ConstStyle Accuracy: 77.60% | MixStyle Accuracy: 72.31%
Epoc

### Normal training

In [None]:
# model.to(device)
# stored_label = []
# for epoch in range(num_epoch):        
#     model.train()
#     running_loss = 0.0
#     for inputs, labels in train_loader:
#         inputs, labels = inputs.to(device), labels.to(device)

#         optimizer.zero_grad()
            
#         stored_label.extend(labels.detach().cpu())
#         outputs = model(inputs)
            
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()

#         running_loss += loss.item()

#     print(f"Epoch {epoch+1}/{num_epoch}, Loss: {running_loss/len(train_loader)}")

#     model.eval()
#     correct_predictions = 0
#     total_samples = 0

#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)

#             # Forward pass
#             outputs = model(inputs)

#             # Calculate accuracy
#             _, predicted = torch.max(outputs, 1)
#             total_samples += labels.size(0)
#             correct_predictions += (predicted == labels).sum().item()

#     # Calculate test accuracy
#     test_accuracy = correct_predictions / total_samples
#     print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

# print("Training finished")

### Train and plot domain

In [None]:
classes = ['sketch', 'art', 'cartoon']
scatter = plt.scatter(transformed_mean[:, 0], transformed_std[:, 0], c=stored_label)
plt.legend(handles=scatter.legend_elements()[0], labels=classes)
plt.show();

In [None]:
class IdentityLayer(nn.Module):
    def __init__(self, in_features):
        super(IdentityLayer, self).__init__()
        self.fc = nn.Identity()

    def forward(self, x):
        out = self.fc(x)
        return out

In [None]:
model.layer2 = nn.Identity()
model.layer3 = nn.Identity()
model.layer4 = nn.Identity()
model.avgpool = nn.Identity()
model.fc = nn.Identity()

In [None]:
photo_image_idx = [[] for _ in range(7)]
art_image_idx = [[] for _ in range(7)]
cartoon_image_idx = [[] for _ in range(7)]
sketch_image_idx = [[] for _ in range(7)]

In [None]:
for i in range(7):
    for idx, val in enumerate(photo_dataset.targets):
        if val == i:
            photo_image_idx[i].append(idx)
    for idx, val in enumerate(art_dataset.targets):
        if val == i:
            art_image_idx[i].append(idx)
    for idx, val in enumerate(cartoon_dataset.targets):
        if val == i:
            cartoon_image_idx[i].append(idx)
    for idx, val in enumerate(sketch_dataset.targets):
        if val == i:
            sketch_image_idx[i].append(idx)

### Get random 50 samples of all domains

In [None]:
class_dl_list = []
for i in range(7):
    photo_idx = np.random.choice(photo_image_idx[i], 50)
    art_idx = np.random.choice(art_image_idx[i], 50)
    cartoon_idx = np.random.choice(cartoon_image_idx[i], 50)
    sketch_idx = np.random.choice(sketch_image_idx[i], 50)
    
    photo_subset = Subset(photo_dataset, photo_idx)
    art_subset = Subset(art_dataset, art_idx)
    cartoon_subset = Subset(cartoon_dataset, cartoon_idx)
    sketch_subset = Subset(sketch_dataset, sketch_idx)
    
    class_dataset = ConcatDataset([photo_subset, art_subset, cartoon_subset, sketch_subset])
    class_dl = DataLoader(class_dataset, batch_size=50, shuffle=False)
    class_dl_list.append(class_dl)

### Get random 50 samples of testing domain

In [None]:
# subset_list = []
# for i in range(7):
#     sketch_idx = np.random.choice(sketch_image_idx[i], 50)    
#     sketch_subset = Subset(sketch_dataset, sketch_idx)
#     subset_list.append(sketch_subset)

# class_dataset = ConcatDataset(subset_list)
# class_dl = DataLoader(class_dataset, batch_size=50, shuffle=False)

### Move samples of test domain to list

In [None]:
# total_feats = []
# total_labels = []
# i = 0
# model.to(device)
# for images, labels in class_dl:
#     images, labels = images.to(device), labels.to(device)
#     feats = model(images)
#     total_labels.extend(labels.cpu().detach().numpy())
#     total_feats.append(feats.cpu().detach().numpy())

### Move samples of all domains to list

In [None]:
total_feats = []
total_labels = []
i = 0
model.to('cpu')
for images, labels in class_dl_list[6]:
    feats = model(images)
    feats = feats.reshape((50, 64, 56, 56))
    # feats = feats.reshape((50, 128, 28, 28))
    # feats = feats.reshape((50, 256, 14, 14))
    # feats = feats.reshape((50, 512, 7, 7))
    print(feats.shape)
    total_labels.extend([i for _ in range(50)])
    total_feats.append(feats.cpu().detach().numpy())
    i += 1

In [None]:
total_feats = np.vstack(total_feats)

In [None]:
# tsne = TSNE(n_components=2, random_state=42)
# transformed_feats = tsne.fit_transform(total_feats)

In [None]:
# plt.scatter(transformed_feats[:, 0], transformed_feats[:, 1], c=total_labels)
# plt.legend(['photo', 'art'])
# plt.show();

In [None]:
mean = np.mean(total_feats, axis=(2, 3), keepdims=True)
std = np.std(total_feats, axis=(2, 3), keepdims=True)

In [None]:
mean = np.squeeze(mean)
std = np.squeeze(std)

print(mean.shape, std.shape)

In [None]:
np.save('./mean_value.npy', mean)
np.save('./std_value.npy', std)

In [None]:
concat_feats = np.concatenate([mean[:, np.newaxis, :], std[:, np.newaxis, :]], axis=1)

In [None]:
concat_feats.shape

In [None]:
tsne = TSNE(n_components=1, random_state=42)
transformed_mean = tsne.fit_transform(mean)

tsne2 = TSNE(n_components=1, random_state=42)
transformed_std = tsne2.fit_transform(std)

In [None]:
classes = ['photo', 'art', 'cartoon', 'sketch']
scatter = plt.scatter(transformed_mean[:, 0], transformed_std[:, 0], c=total_labels)
plt.legend(handles=scatter.legend_elements()[0], labels=classes)
plt.show();