In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F



In [None]:
class AdversarialLoss(nn.Module):
    """Binary-cross-entropy GAN losses computed with a shared discriminator.

    Args:
        discriminator: Callable module that maps a batch of images to
            real/fake logits.
    """

    def __init__(self, discriminator):
        super(AdversarialLoss, self).__init__()
        self.discriminator = discriminator

    def forward(self, real_images, fake_images):
        """Compute the generator and discriminator losses for one batch.

        Args:
            real_images: Batch of real samples fed to the discriminator.
            fake_images: Batch of generated samples.

        Returns:
            Tuple ``(g_loss, d_loss)``:
            ``g_loss`` is the BCE of the fake logits against an all-ones
            target and back-propagates into ``fake_images``;
            ``d_loss`` is BCE(real, ones) + BCE(fake, zeros) and is computed
            on detached fakes, so it never back-propagates into
            ``fake_images``.
        """
        real_logits = self.discriminator(real_images)

        # Discriminator branch: detach so d_loss cannot reach the generator.
        fake_logits_detached = self.discriminator(fake_images.detach())

        # Generator branch: a separate pass on the attached fakes.  Reusing
        # the detached logits here would cut the graph and leave the
        # generator without any gradient from g_loss.
        fake_logits = self.discriminator(fake_images)

        d_loss_real = F.binary_cross_entropy_with_logits(
            real_logits, torch.ones_like(real_logits)
        )
        d_loss_fake = F.binary_cross_entropy_with_logits(
            fake_logits_detached, torch.zeros_like(fake_logits_detached)
        )
        g_loss = F.binary_cross_entropy_with_logits(
            fake_logits, torch.ones_like(fake_logits)
        )

        d_loss = d_loss_real + d_loss_fake
        return g_loss, d_loss



In [None]:
class BoundaryLoss(nn.Module):
    """SSIM-based penalty on the seams between the pieces of an n x n grid.

    Args:
        n: Number of pieces along each image side.  The default of 3 is an
            assumption (the grid size was not defined in this class before)
            -- TODO confirm against the puzzle size used in training.
        data_range: Value range of the images, used for the SSIM
            stabilising constants.
    """

    def __init__(self, n=3, data_range=1.0):
        super(BoundaryLoss, self).__init__()
        if n < 1:
            raise ValueError("n must be a positive integer, got {!r}".format(n))
        self.n = n
        self.data_range = data_range

    @staticmethod
    def _edge_ssim(edge_a, edge_b, data_range):
        """Return one SSIM value per sample for two boundary strips.

        ``torch.nn.functional`` has no SSIM, so it is computed here from
        whole-strip statistics (no sliding window): every non-batch dimension
        is flattened and the mean / variance / covariance are taken over it.
        """
        a = edge_a.flatten(start_dim=1)
        b = edge_b.flatten(start_dim=1)

        # Standard SSIM stabilisers, K1 = 0.01 and K2 = 0.03.
        c1 = (0.01 * data_range) ** 2
        c2 = (0.03 * data_range) ** 2

        mean_a = a.mean(dim=1, keepdim=True)
        mean_b = b.mean(dim=1, keepdim=True)
        dev_a = a - mean_a
        dev_b = b - mean_b
        var_a = (dev_a * dev_a).mean(dim=1)
        var_b = (dev_b * dev_b).mean(dim=1)
        cov_ab = (dev_a * dev_b).mean(dim=1)
        mean_a = mean_a.squeeze(1)
        mean_b = mean_b.squeeze(1)

        luminance = (2 * mean_a * mean_b + c1) / (mean_a ** 2 + mean_b ** 2 + c1)
        contrast_structure = (2 * cov_ab + c2) / (var_a + var_b + c2)
        return luminance * contrast_structure

    def forward(self, decoded_images):
        """Compute the boundary loss for a ``(batch, channels, H, W)`` tensor.

        For every row band ``i`` the last pixel row of that band is compared
        with the first pixel row of band ``(i + 1) % n``; columns are handled
        the same way.  The modulo keeps the wrap-around pairing (last band
        against first band) of the original loop.
        NOTE(review): confirm the wrap-around seam is really intended.

        Returns:
            Scalar tensor ``(1 - mean top/bottom SSIM) +
            (1 - mean left/right SSIM)``.

        Raises:
            ValueError: If H or W is not divisible by ``n``.
        """
        batch_size, channels, height, width = decoded_images.size()
        n = self.n
        if height % n != 0 or width % n != 0:
            raise ValueError(
                "image size {}x{} is not divisible into a {}x{} grid".format(
                    height, width, n, n
                )
            )
        piece_h = height // n
        piece_w = width // n

        # new_zeros keeps the accumulators on the input's device and dtype.
        top_bottom_ssim = decoded_images.new_zeros(batch_size)
        left_right_ssim = decoded_images.new_zeros(batch_size)

        for i in range(n):
            neighbour = (i + 1) % n

            # Horizontal seam: bottom row of band i vs top row of its neighbour.
            upper_edge = decoded_images[:, :, (i + 1) * piece_h - 1, :]
            lower_edge = decoded_images[:, :, neighbour * piece_h, :]
            top_bottom_ssim = top_bottom_ssim + self._edge_ssim(
                upper_edge, lower_edge, self.data_range
            )

            # Vertical seam: right column of band i vs left column of its neighbour.
            left_edge = decoded_images[:, :, :, (i + 1) * piece_w - 1]
            right_edge = decoded_images[:, :, :, neighbour * piece_w]
            left_right_ssim = left_right_ssim + self._edge_ssim(
                left_edge, right_edge, self.data_range
            )

        # Average SSIM over the n seams in each direction.
        top_bottom_ssim = top_bottom_ssim / n
        left_right_ssim = left_right_ssim / n

        loss = (1 - top_bottom_ssim.mean()) + (1 - left_right_ssim.mean())
        return loss


In [None]:
#Define the Jigsaw Loss class
class JigsawLoss(nn.Module):
    """Focal-style loss between a predicted distribution and reference labels.

    Args:
        reference_labels: Reference labels that are multiplied element-wise
            with the prediction; must broadcast against a
            ``(batch, categories)`` tensor.
    """

    def __init__(self, reference_labels):
        super(JigsawLoss, self).__init__()
        self.reference_labels = reference_labels

    def forward(self, predicted_distribution):
        """Return the batch-averaged loss for a ``(batch, categories)`` input.

        The value is ``-sum((1 - p * y) ** 2 * log(p)) / batch``.
        NOTE(review): the modulating factor is 1 wherever ``y`` is 0, so
        non-reference categories contribute ``-log(p)`` in full -- confirm
        this matches the intended focal-loss formulation.
        """
        batch_size, _ = predicted_distribution.size()

        # Align the labels with the prediction so CPU-built labels (or plain
        # lists / arrays) also work with predictions on another device.
        reference = torch.as_tensor(
            self.reference_labels,
            dtype=predicted_distribution.dtype,
            device=predicted_distribution.device,
        )

        # Clamp before the log: a probability of exactly 0 would give -inf
        # and turn the whole loss (and its gradients) into inf / NaN.
        eps = torch.finfo(predicted_distribution.dtype).eps
        log_probs = torch.log(predicted_distribution.clamp_min(eps))

        modulating = (1 - predicted_distribution * reference) ** 2
        focal_loss = -torch.sum(modulating * log_probs)

        # Average over batch size
        return focal_loss / batch_size

# Helper functions
def calculate_psnr(img1, img2, max_pixel=255.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        img1: First image (array-like).
        img2: Second image, broadcast-compatible with ``img1``.
        max_pixel: Peak signal value; 255.0 matches 8-bit images.

    Returns:
        ``20 * log10(max_pixel / sqrt(mse))``, or ``float('inf')`` when the
        images are identical (``mse == 0``).
    """
    # Convert to float first: subtracting or squaring unsigned-integer
    # images (e.g. uint8) wraps around and yields a wrong MSE.
    difference = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(difference ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(max_pixel / np.sqrt(mse))

def cut_boundaries(img, pix):
    """Slice the four ``pix``-wide border strips from a 2-D-indexable image.

    Returns:
        Tuple ``(top, bottom, left, right)``: the first ``pix`` rows, the
        last ``pix`` rows, the first ``pix`` columns and the last ``pix``
        columns of ``img``.

    Note:
        With ``pix == 0`` the trailing slices start at ``-0 == 0``, so
        "bottom" and "right" span the whole axis rather than being empty.
    """
    whole_axis = slice(None)
    leading = slice(None, pix)
    trailing = slice(-pix, None)
    return (
        img[leading, whole_axis],
        img[trailing, whole_axis],
        img[whole_axis, leading],
        img[whole_axis, trailing],
    )

def compute_relationships(top_boundaries, bottom_boundaries, left_boundaries, right_boundaries):
    """Score every ordered pair of distinct pieces by boundary PSNR.

    Returns:
        Two ``(n, n)`` float arrays, where ``n = len(top_boundaries)``.
        Entry ``[i, j]`` of the first is
        ``calculate_psnr(top_boundaries[i], bottom_boundaries[j])`` and of
        the second ``calculate_psnr(left_boundaries[i], right_boundaries[j])``.
        Diagonal entries (``i == j``) are left at 0.
    """
    piece_count = len(top_boundaries)
    shape = (piece_count, piece_count)
    top_bottom_scores = np.zeros(shape)
    left_right_scores = np.zeros(shape)

    for row in range(piece_count):
        for col in range(piece_count):
            # A piece is never matched against itself.
            if row == col:
                continue
            top_bottom_scores[row, col] = calculate_psnr(
                top_boundaries[row], bottom_boundaries[col]
            )
            left_right_scores[row, col] = calculate_psnr(
                left_boundaries[row], right_boundaries[col]
            )

    return top_bottom_scores, left_right_scores

def greedy_algorithm(top_bottom_relationships, left_right_relationships):
    """Pick the ``n`` highest-scoring entries of each relationship matrix.

    Args:
        top_bottom_relationships: Square score matrix (array-like).
        left_right_relationships: Score matrix (array-like); ``n`` picks are
            taken from it as well, where ``n`` is the length of the first
            matrix.

    Returns:
        Tuple ``(selected_top_bottom, selected_left_right)``; each is a list
        of ``n`` index tuples as returned by ``np.unravel_index``, ordered
        from highest to lowest score.
    """
    # Work on float copies: the selection loop overwrites chosen entries with
    # -inf, which must not clobber the caller's matrices and cannot be stored
    # in an integer array.
    top_bottom_scores = np.array(top_bottom_relationships, dtype=float)
    left_right_scores = np.array(left_right_relationships, dtype=float)

    n = len(top_bottom_scores)
    selected_top_bottom = []
    selected_left_right = []

    for _ in range(n):
        best_top_bottom = np.unravel_index(
            np.argmax(top_bottom_scores), top_bottom_scores.shape
        )
        best_left_right = np.unravel_index(
            np.argmax(left_right_scores), left_right_scores.shape
        )

        selected_top_bottom.append(best_top_bottom)
        selected_left_right.append(best_left_right)

        # Mask the chosen entries so the next round picks the next best.
        top_bottom_scores[best_top_bottom] = -np.inf
        left_right_scores[best_left_right] = -np.inf

    return selected_top_bottom, selected_left_right

def minimum_spanning_tree(selected_top_bottom, selected_left_right):
    """Derive a list of ``(parent, child)`` pairs from the selected pairs.

    The loop is laid out like a key/parent/visited spanning-tree
    construction over ``n = len(selected_top_bottom)`` vertices.

    Args:
        selected_top_bottom: Sequence of 2-element items, indexed as
            ``[k][0]`` and ``[k][1]`` (presumably the index tuples returned
            by ``greedy_algorithm`` -- verify against the caller).
        selected_left_right: Sequence of 2-element items used the same way;
            must have at least ``n`` entries.

    Returns:
        List of ``(parent[i], i)`` for ``i`` in ``1 .. n-1``.  ``parent[i]``
        is ``-1`` for any vertex whose parent was never assigned.

    NOTE(review): keys are assigned from elements of the selected pairs
    rather than from any score, and the update conditions compare those
    elements with ``>`` -- confirm this is the intended weighting.
    """
    # Initialize variables
    n = len(selected_top_bottom)
    visited = [False] * n
    parent = [-1] * n
    key = [float('inf')] * n

    # Start with the first node
    key[0] = 0

    # Construct MST
    for _ in range(n):
        # Find the vertex with the minimum key value
        # (strict '<' means the lowest index wins ties).
        min_key = float('inf')
        min_idx = -1
        for i in range(n):
            if not visited[i] and key[i] < min_key:
                min_key = key[i]
                min_idx = i

        # Mark the selected vertex as visited
        # NOTE(review): if no unvisited vertex has a finite key, min_idx is
        # still -1 here, so this marks (and the code below reads) the LAST
        # element via negative indexing -- confirm that is acceptable.
        visited[min_idx] = True

        # Update key and parent for adjacent vertices
        for j in range(n):
            if not visited[j]:
                # Update key if the weight is smaller
                if (selected_top_bottom[min_idx][0] == j or selected_left_right[min_idx][0] == j) and selected_top_bottom[min_idx][1] > selected_top_bottom[j][1]:
                    key[j] = selected_top_bottom[min_idx][1]
                    parent[j] = min_idx
                # This second check runs even when the first one matched, so
                # it can overwrite the key/parent that was just assigned.
                if (selected_top_bottom[min_idx][1] == j or selected_left_right[min_idx][1] == j) and selected_left_right[min_idx][0] > selected_left_right[j][0]:
                    key[j] = selected_left_right[min_idx][0]
                    parent[j] = min_idx

    # Construct the reference permutation
    # (vertex 0 is the start vertex and is skipped).
    reference_permutation = []
    for i in range(1, n):
        reference_permutation.append((parent[i], i))

    return reference_permutation

