In [None]:
import torch
torch.manual_seed(42)
import torch.nn as nn
import torch.optim as optim

import numpy as np
np.random.seed(42)

from make_analogies_functions import *

In [None]:
# co = {"top": 2,  "bottom": 7, "left": 7, "right": 9}
# im = create_image_with_white_rectangle(co, 10)
# print(im)
# print()
# gravity(im, co, "down")

In [None]:
# Build the analogy dataset. Every sample consists of three random rectangle
# images (a "trio") together with that same trio under each transformation
# listed in `method_names`, all transformations sharing one set of sampled parameters.
num_samples = 100000 #each sample is a trio of images of which each comes in various forms (analogous transformations).
img_size = 10
all_images = []  # NOTE(review): never read in the visible cells — confirm whether it is still needed
# Display names of the transformations, in the same order as `transformed_trios` below.
method_names = ["Resized", "Moved", "Rotated", "Inverted", "Mirrored", "Close/Far Corners", "Close/Far Edges", "Stretched", "Shadows", "Gravity"]
seed_iteration = 0  # counts every attempt (accepted or rejected); doubles as the NumPy seed
data = []

for i in range(num_samples):    
    trios = []  # NOTE(review): assigned but never used in the visible code
    invalid_img = True
    # Rejection sampling: keep drawing until the trio passes the checks at the bottom of the loop.
    while invalid_img:
        # Re-seed with a fresh counter value so each attempt is reproducible on its own.
        seed_iteration += 1
        np.random.seed(seed_iteration)

        #make 3 random images with rectangles
        # top/left are drawn from [1, img_size-2]; bottom/right from [top+1, img_size-1]
        # and [left+1, img_size-1] (randint's upper bound is exclusive).
        top_lefts = [{"top": np.random.randint(1, img_size-1), "left": np.random.randint(1, img_size-1)} for _ in range(3)]
        bottom_rights = [{"bottom": np.random.randint(d["top"] + 1, img_size), "right": np.random.randint(d["left"] + 1, img_size)} for d in top_lefts]
        # Merge each pair into one dict with keys top/left/bottom/right.
        coords = [top_left | bottom_right for top_left, bottom_right in zip(top_lefts, bottom_rights)]
        trio = [create_image_with_white_rectangle(coord, img_size) for coord in coords]
   
        #sample parameters for analogies
        # One value per parameter is drawn and then applied to all three images of the trio.
        # The order of these draws determines the generated data for a given seed.
        mirror_horizontal = np.random.choice([True, False])
        rotation_degree = np.random.choice([90, 180, 270])
        # 0 is listed twice, so "no growth" is drawn half of the time for each side.
        grow_left = np.random.choice([0,0,1,2])
        grow_right = np.random.choice([0,0,1,2])
        grow_top = np.random.choice([0,0,1,2])
        grow_bottom = np.random.choice([0,0,1,2])
        move_vertical = np.random.choice([0,1,2])
        move_horizontal = np.random.choice([0,1,2])
        furthest_edge = np.random.choice([True, False])
        furthest_corner = np.random.choice([True, False])
        reverse_shadows = np.random.choice([True, False])
        gravity_direction = np.random.choice(["up","down","left","right"])

        # Generate analogies
        # Each list holds the transformed version of the three trio images, in trio order.
        resizes = [resize_rectangle(img, coord, grow_top, grow_bottom, grow_left, grow_right) for img, coord in zip(trio, coords)]
        moves = [move_rectangle(img, move_horizontal, move_vertical)  for img in trio]
        rotations = [rotate_image(img, rotation_degree) for img in trio]
        inversions = [invert_colors(img)  for img in trio]
        mirrors = [mirror_image(img, horizontal=mirror_horizontal) for img in trio]
        corner_cells = [paint_corner(img, furthest_corner) for img in trio]
        edges = [paint_edge(img, coord, furthest_edge) for img, coord in zip(trio, coords)]
        stretches = [stretch_rectangle(img, coord) for img, coord in zip(trio, coords)]
        shadows = [draw_shadows(img, coord, reverse_shadows)  for img, coord in zip(trio, coords)]
        gravities = [gravity(img, coord, gravity_direction)  for img, coord in zip(trio, coords)]
        
        #check whether images violate rules (original three include duplicates; initial transformations left the canvas)
        # NOTE(review): only element 0 of `moves` and `resizes` is validated; the second and
        # third images of the trio are not checked here — confirm this is intended.
        # `invalid_matrix` is defined outside this notebook; its exact criterion is not visible here.
        if np.array_equal(trio[0], trio[1]) or np.array_equal(trio[0], trio[2]) or np.array_equal(trio[1], trio[2]):
            invalid_img = True
        elif invalid_matrix(moves[0], img_size, img_size, 1):
            invalid_img = True
        elif invalid_matrix(resizes[0], img_size, img_size, 1):
            invalid_img = True
        else:
            invalid_img = False
    
    # Pair the original trio with each transformed trio: one stacked [trio, transformed] array per method.
    transformed_trios = [resizes, moves, rotations, inversions, mirrors, corner_cells, edges, stretches, shadows, gravities]
    data.append([np.stack([trio, transformed_trio]) for transformed_trio in transformed_trios])

# Axes: (sample, method, original/transformed, image-in-trio, then the image axes —
# presumably img_size x img_size; confirm against create_image_with_white_rectangle).
data = np.array(data)
print(data.shape)

In [None]:
# Reformat the dataset into one row per task; specific slices could be carved off here
# (e.g., to build a test set containing unseen analogies).
# Each row stacks the original trio followed by the transformed trio along axis 1.
n_samples, n_methods, n_versions, n_per_trio, n_rows, n_cols = data.shape
long_data = data.reshape(n_samples * n_methods, n_versions * n_per_trio, n_rows, n_cols)

# Rows 0-9 are the ten transformations of the first trio (index 8 is the ninth of them);
# row 10 starts over with the second trio.
plot_double_trio(long_data[8])

# Drop duplicated tasks; could be stricter by also treating a flipped few-shot order as a duplicate.
nonduplicates = np.unique(long_data, axis=0)
duplicate_share = 1 - nonduplicates.shape[0] / long_data.shape[0]
print(f"{np.round(100*duplicate_share,1)}% double trios were duplicated")

In [None]:
# Display the shape of the flattened task array (this is the array before duplicate removal).
long_data.shape

In [None]:
# Persist the de-duplicated tasks so later runs can reuse them without regenerating.
# `pickle` is imported explicitly here: nothing else in the notebook imports it, so the cell
# would otherwise only run if the wildcard import from make_analogies_functions re-exports it.
import pickle

with open("nonduplicates.pkl", "wb") as f:
    pickle.dump(nonduplicates, file=f)

In [None]:
import torch
from sklearn.model_selection import train_test_split

# Inputs are every image of a task except the last one; the last image is the target.
# Values are divided by 255 (presumably pixels are stored on a 0-255 scale — TODO confirm)
# and converted to float32 tensors.
x_data = torch.from_numpy(nonduplicates[:, :-1] / 255).float()
y_data = torch.from_numpy(nonduplicates[:, -1] / 255).float()

# Hold out 0.5% of the tasks for testing, with a fixed seed for a repeatable split.
x_train, x_test, y_train, y_test = train_test_split(
    x_data,
    y_data,
    test_size=0.005,
    random_state=42,
)

# Report the shape of each split.
for caption, split in (
    ("x_train shape:", x_train),
    ("y_train shape:", y_train),
    ("x_test shape:", x_test),
    ("y_test shape:", y_test),
):
    print(caption, split.shape)

In [None]:
#basic feed forward test
class FullyConnectedNN(nn.Module):
    """Plain multilayer perceptron mapping a stack of images to one output image.

    The input is flattened, passed through three hidden ReLU layers of equal
    width, and projected to ``img_size * img_size`` sigmoid outputs that are
    reshaped into a square image.

    Args:
        in_channels: Number of images stacked in each input sample (default 5).
        img_size: Side length of the square input and output images (default 10).
        hidden_size: Width of each hidden layer (default 150).

    The defaults reproduce the originally hard-coded architecture, and the
    layer attribute names are unchanged, so existing ``state_dict`` keys
    still match.
    """

    def __init__(self, in_channels=5, img_size=10, hidden_size=150):
        super().__init__()
        self.img_size = img_size
        n_pixels = img_size * img_size

        # Layers are created in the original order (fc1..fc4) so that seeded
        # weight initialisation is unaffected by the refactor.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_channels * n_pixels, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(hidden_size, n_pixels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(batch, img_size, img_size)`` tensor with values in [0, 1].

        ``x`` must flatten to ``in_channels * img_size * img_size`` features per sample.
        """
        x = self.flatten(x)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.relu3(self.fc3(x))
        x = self.sigmoid(self.fc4(x))
        # Restore the 2-D image layout of the prediction.
        return x.view(-1, self.img_size, self.img_size)

# Build the network to be trained.
model = FullyConnectedNN()

# Mean-squared error between predicted and target images, optimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0)

num_epochs = 500
for epoch in range(num_epochs):
    # One optimisation step on the entire training set (no mini-batching).
    model.train()
    optimizer.zero_grad()
    train_outputs = model(x_train)
    loss = criterion(train_outputs, y_train)
    loss.backward()
    optimizer.step()

    # Only evaluate and report on every tenth epoch.
    if epoch % 10 != 0:
        continue

    model.eval()  # evaluation mode for the held-out pass
    with torch.no_grad():
        test_outputs = model(x_test)
        test_loss = criterion(test_outputs, y_test)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Test: {test_loss.item():.4f}')

In [None]:
# Count the scalar entries of every parameter tensor that receives gradients.
pytorch_total_params = 0
for p in model.parameters():
    if p.requires_grad:
        pytorch_total_params += p.numel()
print(f"Nr params: {pytorch_total_params}")

In [None]:
# Index of the held-out task to visualise.
plot_exmp = 0
with torch.no_grad():
    # Give the target image a leading axis so it can be appended to the input stack.
    y = y_test[plot_exmp].unsqueeze(0)
    example = torch.cat((x_test[plot_exmp], y), dim=0)
    plot_double_trio(example)

    # Run the model on a batch of one and take the single predicted image.
    prediction = model(x_test[plot_exmp].unsqueeze(0))[0]

    # Show the raw sigmoid output (no thresholding applied).
    plt.figure(figsize=(6, 2))
    plt.title("Non-binarized Output")
    plt.imshow(prediction, cmap='gray')
    plt.axis('off')
    plt.show()