In [2]:
#make transformations valid for rectangles and Ls
import torch
torch.manual_seed(42)
import torch.nn as nn
import torch.optim as optim

import numpy as np
np.random.seed(42)

import pickle

from make_analogies_functions import *
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


In [None]:
# np.random.seed(22)
# im = create_image(10, shape="rectangle")
# print(im)
# count_pixels(im, True, True)

In [None]:
num_samples = 100  # each sample is a trio of images; every image also comes in each transformed form (analogous transformations)
img_size = 10
pairs_per_task = 3
shape = "L"  # "rectangle" or "L"
all_images = []
method_names = ["Grown", "Moved", "Rotated", "Inverted", "Mirrored", "Close/Far Corners", "Close/Far Edges", "Stretched", "Shadows", "Gravity", "Count"]
seed_iteration = 0
data = []

for _ in range(num_samples):
    invalid_img = True
    while invalid_img:
        # Re-seed on every attempt so each accepted sample can be regenerated from its seed alone.
        seed_iteration += 1
        np.random.seed(seed_iteration)

        # make `pairs_per_task` random images containing the configured `shape`
        trio = [create_image(img_size=img_size, shape=shape) for _ in range(pairs_per_task)]

        # sample the analogy parameters once per attempt; they are shared by every image of the trio
        mirror_horizontal = np.random.choice([True, False])
        rotation_degree = np.random.choice([90, 180, 270])
        grow_left = np.random.choice([0, 0, 1, 2])
        grow_right = np.random.choice([0, 0, 1, 2])
        grow_top = np.random.choice([0, 0, 1, 2])
        grow_bottom = np.random.choice([0, 0, 1, 2])
        move_vertical = np.random.choice([-2, -1, 0, 1, 2])
        move_horizontal = np.random.choice([-2, -1, 0, 1, 2])
        furthest_edge = np.random.choice([True, False])
        furthest_corner = np.random.choice([True, False])
        reverse_shadows = np.random.choice([True, False])
        gravity_direction = np.random.choice(["up", "down", "left", "right"])
        count_left_right = np.random.choice([True, False])
        count_top_bottom = np.random.choice([True, False])

        # Generate analogies
        growths = [grow(img, grow_top, grow_bottom, grow_left, grow_right) for img in trio]
        moves = [move(img, move_horizontal, move_vertical) for img in trio]
        rotations = [rotate_image(img, rotation_degree) for img in trio]
        inversions = [invert_colors(img) for img in trio]
        mirrors = [mirror_image(img, horizontal=mirror_horizontal) for img in trio]
        corner_cells = [paint_corner(img, furthest_corner) for img in trio]
        edges = [paint_edge(img, furthest_edge) for img in trio]
        stretches = [stretch_rectangle(img) for img in trio]
        shadows = [draw_shadows(img, reverse_shadows) for img in trio]
        gravities = [gravity(img, gravity_direction) for img in trio]
        counts = [count_pixels(img, count_left_right, count_top_bottom) for img in trio]

        # Reject the attempt if any two original images are identical, or if the growth or
        # move result of ANY image fails `invalid_matrix` (presumably: the shape left the
        # canvas -- TODO confirm against make_analogies_functions). Previously only
        # growths[0] was checked (twice) and moves were never checked.
        has_duplicates = any(
            np.array_equal(trio[a], trio[b])
            for a in range(len(trio))
            for b in range(a + 1, len(trio))
        )
        invalid_img = has_duplicates or any(
            invalid_matrix(img, img_size, img_size, 1) for img in growths + moves
        )

    transformed_trios = [growths, moves, rotations, inversions, mirrors, corner_cells, edges, stretches, shadows, gravities, counts]
    # one entry per transformation: the original trio stacked with its transformed trio
    data.append([np.stack([trio, transformed_trio]) for transformed_trio in transformed_trios])

# presumably (num_samples, n_transformations, 2, pairs_per_task, H, W) -- confirm with the printed shape
data = np.array(data)
print(data.shape)

In [None]:
# Reformat the generated array: one row per task. Specific data could be sliced off here
# (e.g. for a test set containing only new analogies).
# Both leading axes (sample, transformation) merge into a single task axis, and the
# (original/transformed) x image axes merge into one axis of images per task.
n_tasks = data.shape[0] * data.shape[1]
n_imgs_per_task = data.shape[2] * data.shape[3]
long_data = data.reshape(n_tasks, n_imgs_per_task, data.shape[4], data.shape[5])
# Drop exact duplicate tasks; could be stricter by also treating permuted few-shot orders as duplicates.
nonduplicates = np.unique(long_data, axis=0)
duplicated_share = 1 - nonduplicates.shape[0] / long_data.shape[0]
print(f"{np.round(100*duplicated_share,1)}% double trios were duplicated") 
print(long_data.shape)

In [None]:
# Visual sanity check of the first two tasks.
for task_idx in (0, 1):
    plot_double_trio(long_data[task_idx])

In [None]:
# Cache the deduplicated tasks on disk. If the cache exists it REPLACES the freshly generated
# `nonduplicates`; otherwise the current ones are saved, so Restart & Run All works on a clean
# checkout. (The old commented-out dump wrote "nonduplicates.pkl" while the load read
# "nonduplicates_L.pkl", so the two never matched.)
# The "_L" suffix mirrors `shape = "L"` in the generation cell -- keep them in sync.
# NOTE: pickle.load can execute arbitrary code; only load cache files this notebook wrote.
import os

NONDUPLICATES_CACHE = "nonduplicates_L.pkl"
if os.path.exists(NONDUPLICATES_CACHE):
    with open(NONDUPLICATES_CACHE, "rb") as f:
        nonduplicates = pickle.load(file=f)
else:
    with open(NONDUPLICATES_CACHE, "wb") as f:
        pickle.dump(nonduplicates, file=f)

In [None]:
import torch
from sklearn.model_selection import train_test_split

# Axis 1 of `nonduplicates` holds the images of a task (not channels). The model receives
# every image except the last one and has to predict that final image.
# Values are scaled by 1/255 (assumes a 0-255 pixel encoding -- TODO confirm).
x_data = torch.from_numpy(nonduplicates[:, :-1] / 255).float()
y_data = torch.from_numpy(nonduplicates[:, -1] / 255).float()

# Hold out a very small test fraction (0.5%) with a fixed split seed.
x_train, x_test, y_train, y_test = train_test_split(
    x_data, y_data, test_size=0.005, random_state=42
)

for label, split in (
    ("x_train shape:", x_train),
    ("y_train shape:", y_train),
    ("x_test shape:", x_test),
    ("y_test shape:", y_test),
):
    print(label, split.shape)

batch_size = 512  # tune as needed
train_dataset = torch.utils.data.TensorDataset(x_train, y_train)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)

In [3]:
# Basic feed-forward baseline.
class FullyConnectedNN(nn.Module):
    """Fully connected network that predicts the missing image of an analogy task.

    The input images are flattened, passed through three ReLU-activated hidden layers,
    and a sigmoid maps every output pixel into (0, 1) (compatible with ``nn.BCELoss``).

    Args:
        n_inputs: number of input images per task (default 5).
        img_size: side length of the square images (default 10).
        hidden_size: width of each hidden layer (default 200).

    The defaults reproduce the original fixed architecture, and the layer attribute
    names (``fc1``..``fc4``) are unchanged, so previously saved state dicts still load.
    """

    def __init__(self, n_inputs=5, img_size=10, hidden_size=200):
        super().__init__()
        self.img_size = img_size  # plain int, not part of the state dict
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(n_inputs * img_size * img_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(hidden_size, img_size * img_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a batch (batch, n_inputs, img_size, img_size) to (batch, img_size, img_size)."""
        x = self.flatten(x)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.relu3(self.fc3(x))
        x = self.sigmoid(self.fc4(x))
        return x.view(-1, self.img_size, self.img_size)

In [None]:
num_epochs = 200
learning_rate = 0.01

# Create an instance of the model on the selected device
model = FullyConnectedNN().to(device)

# BCELoss pairs with the model's sigmoid output; the LR decays linearly to 10% over training.
# (`verbose=True` is deprecated for LR schedulers; the current LR is logged below instead.)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1, end_factor=0.1, total_iters=num_epochs)

# The test tensors are created on the CPU. Move copies once so evaluation runs on the same
# device as the model (otherwise this crashes on CUDA). x_test/y_test stay on the CPU for later cells.
x_test_dev, y_test_dev = x_test.to(device), y_test.to(device)

for epoch in range(num_epochs):
    model.train()
    running_loss, n_seen = 0.0, 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        train_outputs = model(inputs)
        loss = criterion(train_outputs, labels)
        loss.backward()
        optimizer.step()
        # BCELoss returns a mean; re-weight by batch size so the epoch figure is an exact mean
        # (the old log showed only the last, possibly partial, batch).
        running_loss += loss.item() * inputs.shape[0]
        n_seen += inputs.shape[0]
    scheduler.step()

    model.eval()
    with torch.no_grad():
        test_loss = criterion(model(x_test_dev), y_test_dev)
    train_loss = running_loss / max(n_seen, 1)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss:.4f}, Test: {test_loss.item():.4f}, LR: {scheduler.get_last_lr()[0]:.6f}')

torch.save(model.state_dict(), 'model.pth')

In [4]:
# Reload the trained weights into a fresh CPU model for inspection.
# map_location="cpu": the training cell saves the state dict of a model moved to `device`
# (possibly CUDA). Without remapping, loading fails on a CPU-only machine.
model = FullyConnectedNN().eval()
model.load_state_dict(torch.load("model.pth", map_location="cpu"))

<All keys matched successfully>

In [None]:
# Size of the model: number of trainable (requires_grad) parameters.
pytorch_total_params = sum(map(torch.numel, filter(lambda param: param.requires_grad, model.parameters())))
print(f"Nr params: {pytorch_total_params}")

In [None]:
plot_exmp = 38
# NOTE(review): `plt` has no visible import in this notebook; presumably it comes from
# `from make_analogies_functions import *` -- confirm.


def _with_target(context, target):
    """Append one target image (H, W) to a stack of context images (N, H, W) -> (N+1, H, W)."""
    return torch.cat((context, target.unsqueeze(0)), dim=0)


with torch.no_grad():
    print("Test set example:")
    plot_double_trio(_with_target(x_test[plot_exmp], y_test[plot_exmp]))

    print("Prediction:")
    prediction = model(x_test[plot_exmp].unsqueeze(0))[0]
    plt.figure(figsize=(6, 2))
    plt.title("Non-binarized Output")
    plt.imshow(prediction.cpu().numpy(), cmap='gray')
    plt.axis('off')
    plt.show()

    print("Most similar case in training data:")
    # L1 distance from the test context to every training context, computed in one vectorized
    # pass instead of a per-row Python loop. Kept as a list so `distance` has the same type as before.
    distance = (x_train - x_test[plot_exmp]).abs().flatten(start_dim=1).sum(dim=1).tolist()
    sim_index = np.argmin(distance)
    plot_double_trio(_with_target(x_train[sim_index], y_train[sim_index]))