In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
from tqdm.notebook import tqdm
import random
import matplotlib.pyplot as plt


In [18]:

# Define a class for applying random style deformations to images
class RandomStyleDeformation:
    def __init__(self, deformation_list):
        self.deformation_list = deformation_list

    def __call__(self, image):
        # Randomly select a style deformation function from the list
        deformation_func = random.choice(self.deformation_list)
        
        # Apply the selected deformation to the image
        return deformation_func(image)

# Define a random rotation style deformation function
def style_deformation_1(image):
    # Apply a random color jitter transformation
    color_jitter = transforms.ColorJitter(
        brightness=random.uniform(0.8, 1.2),
        contrast=random.uniform(0.8, 1.2),
        saturation=random.uniform(0.8, 1.2),
        hue=random.uniform(0.1, 0.2)
    )
    image = color_jitter(image)
    return image

# Define a random brightness adjustment style deformation function
def style_deformation_2(image):
    # Apply a random brightness adjustment
    brightness_factor = random.uniform(0.8, 1.2)
    image = transforms.functional.adjust_brightness(image, brightness_factor)
    return image

# List of style deformation functions
deformation_list = [style_deformation_1, style_deformation_2]

# Create a RandomStyleDeformation object
random_style_deformation = RandomStyleDeformation(deformation_list)


In [19]:

# Define the model architecture (e.g., 'efficientnet_b0' for the smallest EfficientNet)
model_name = 'efficientnet_b0'

# Load the pre-trained EfficientNet-B0 model from a local file
model_file = model_name + '.pth'
efficientnet_b0 = torch.load(model_file).eval()

# Set the number of classes in your task (adjust as needed)
k = 16
out_features = k * k

# Replace the final classifier with a new fully connected layer
in_features = efficientnet_b0.classifier.in_features
# Replace the final classifier with two branches, each outputting a vector
efficientnet_b0.classifier = nn.Linear(in_features, 2 * out_features)  # 2 stands for two vectors, r and d


In [20]:

# Define image transformations for preprocessing
downsample_size = (224, 224)  # Desired downsampled size
image_transform = transforms.Compose([
    transforms.Resize(downsample_size),  # Downsample the image to match model input size
    transforms.ToTensor(),  # Convert the image to a PyTorch tensor
])

# Define a custom dataset for your test images
class TestImageDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx])
        if self.transform:
            image = self.transform(image)
        return image

# Path to your test images
image_paths = ['data/1674921468776855 (2).jpeg']

# Define the height and width for resizing
height, width = 1536, 1536

# Define image transformations (you can adjust these as needed)
image_transform = transforms.Compose([
    transforms.Resize((height, width)),  # Resize your images to the desired size
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    # Add more transformations as necessary, e.g., normalization
])

# Create a DataLoader for your test images
test_dataset = TestImageDataset(image_paths, transform=image_transform)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)  # Batch size 1 to process one image at a time


In [21]:

# Define the MatrixProductModel
class MatrixProductModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MatrixProductModel, self).__init__()
        
        self.R = nn.Parameter(torch.randn(input_dim, output_dim), requires_grad=True)
        self.Q = nn.Parameter(torch.randn(output_dim, input_dim), requires_grad=True)

    def forward(self, x, T):
        # Perform matrix-vector multiplication Q^T * T * R * x
        y =  torch.matmul(torch.matmul(torch.matmul(x, self.R), T), self.Q)
        return y

# Custom loss function that combines L1 norms and L2 norm
def custom_loss(Z_1, Z_2, Y_1, Y_2, I_1, I_2, coef_l):
    # Calculate L2 norm
    l2_norm = F.mse_loss(Z_1, Z_2)

    # Calculate L1 norms
    l1_norm1 = F.l1_loss(Y_1, I_1)
    l1_norm2 = F.l1_loss(Y_2, I_2)

    # Combine the L1 norms as needed
    loss = coef_l * l2_norm + l1_norm1 + l1_norm2
    
    return loss


In [22]:

# Lambda coefficient
coef_l = 10

# Initialize the model and optimizer with Adam
model_n = MatrixProductModel(3, k)
model_s = MatrixProductModel(3, k)

optimizer_n = optim.Adam(model_n.parameters(), lr=0.01)
optimizer_s = optim.Adam(model_s.parameters(), lr=0.01)

# Training loop
num_epochs = 100


In [24]:

for epoch in tqdm(range(num_epochs)):
    efficientnet_b0.train() 
    model_n.train()
    model_s.train()

    for image in test_dataloader:

        # Get random style transformations
        x1 = random_style_deformation(image)
        x2 = random_style_deformation(image)

        # Get T from the encoder 
        downsampled_x1 = F.interpolate(x1, size=downsample_size, mode='bilinear', align_corners=False)
        downsampled_x2 = F.interpolate(x2, size=downsample_size, mode='bilinear', align_corners=False)

        # Output features = 256, so k is fixed to 16
        # d_i, r_i - normalized color space, color style
        d_1, r_1 = efficientnet_b0(downsampled_x1).chunk(2, dim=1)
        d_2, r_2 = efficientnet_b0(downsampled_x1).chunk(2, dim=1)

        optimizer_n.zero_grad()
        optimizer_s.zero_grad()

        # Get normalized color space pictures
        Z_1 = model_n(x1.reshape(-1, 3), d_1.reshape(k, k))
        Z_2 = model_n(x2.reshape(-1, 3), d_2.reshape(k, k))

        # Train them to get the same color style
        Y_1 = model_s(Z_1, r_2.reshape(k, k))
        Y_2 = model_s(Z_2, r_1.reshape(k, k))

        # Calculate loss, coef lambda = 10
        loss = custom_loss(Z_1, Z_2, Y_1, Y_2, x1.reshape(-1, 3), x2.reshape(-1, 3), coef_l)
        loss.backward()

        optimizer_n.step()
        optimizer_s.step()
        
        # Print the loss value at each step
        print("Epoch {}: Loss = {:.4f}".format(epoch, loss.item()))


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 0: Loss = 30.6618
Epoch 1: Loss = 1.7765
Epoch 2: Loss = 1.8757
Epoch 3: Loss = 9.6255
Epoch 4: Loss = 7.5265
Epoch 5: Loss = 17.8318
Epoch 6: Loss = 3.7520
Epoch 7: Loss = 19.9365
Epoch 8: Loss = 7.7490
Epoch 9: Loss = 0.9695
Epoch 10: Loss = 5.6036
Epoch 11: Loss = 4.2850
Epoch 12: Loss = 2.2225
Epoch 13: Loss = 1.4507
Epoch 14: Loss = 3.4479
Epoch 15: Loss = 1.1582
Epoch 16: Loss = 1.4837
Epoch 17: Loss = 0.6115
Epoch 18: Loss = 0.4796
Epoch 19: Loss = 4.2852
Epoch 20: Loss = 0.6729
Epoch 21: Loss = 1.5469
Epoch 22: Loss = 0.5737
Epoch 23: Loss = 3.1638
Epoch 24: Loss = 0.2474
Epoch 25: Loss = 16.0829
Epoch 26: Loss = 4.1785
Epoch 27: Loss = 2.3873
Epoch 28: Loss = 0.1305
Epoch 29: Loss = 0.0660


KeyboardInterrupt: 