In [1]:
!pip install git+https://github.com/nivasini/EGG.git

Collecting git+https://github.com/nivasini/EGG.git
  Cloning https://github.com/nivasini/EGG.git to /tmp/pip-req-build-i3h76ek9
  Running command git clone --filter=blob:none --quiet https://github.com/nivasini/EGG.git /tmp/pip-req-build-i3h76ek9
  Resolved https://github.com/nivasini/EGG.git to commit f36d123af22eb0d127d2089c993b4eff8314a43d
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting editdistance (from EGG==0.1.0)
  Downloading editdistance-0.8.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)
Collecting dataclasses (from EGG==0.1.0)
  Downloading dataclasses-0.6-py3-none-any.whl.metadata (3.0 kB)
Collecting wandb (from EGG==0.1.0)
  Downloading wandb-0.19.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting submitit (from EGG==0.1.0)
  Downloading submitit-1.5.2-py3-none-any.whl.metadata (7.9 kB)
Collecting docker-pycreds>=0.4.0 (from wandb->EGG==0.1.0)
  Downloading docker_pycreds-0.4.0-py2.py3-none-a

In [17]:
%matplotlib inline
%load_ext autoreload
%autoreload 2


import torch
import torch.nn as nn
import egg.core as core

from torchvision import datasets, transforms
from torch import nn
from torch.nn import functional as F

import matplotlib.pyplot as plt
import random
import numpy as np
import random

from pylab import rcParams
rcParams['figure.figsize'] = (5, 10)  # (width, height) used for every figure below

# EGG-level command-line options, fixed here for convenience and reproducibility.
# core.init parses them; the returned namespace supplies the batch size below,
# and core.build_optimizer reads the lr / optimizer choice from it.
opts = core.init(params=[
    # '--random_seed=7',  # uncomment to seed the numpy, torch and python RNGs
    '--lr=1e-3',          # learning rate for the selected optimizer
    '--batch_size=32',
    '--optimizer=adam',
])

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Worker / pinned-memory DataLoader options are only enabled when CUDA is available.
kwargs = dict(num_workers=1, pin_memory=True) if torch.cuda.is_available() else {}
transform = transforms.ToTensor()

batch_size = opts.batch_size  # comes from '--batch_size' above
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
    **kwargs,
)
# Test split is not shuffled, so evaluation order is fixed. No download=True here:
# presumably the training-split download above already fetched these files.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=transform),
    batch_size=batch_size,
    shuffle=False,
    **kwargs,
)



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
# Sanity check of the loader configuration: batch size and batches per epoch.
print("Batch size: ", batch_size)
print("Number of batches in train_loader: ", len(train_loader))
print("Number of batches in test_loader: ", len(test_loader))


Batch size:  32
Number of batches in train_loader:  1875
Number of batches in test_loader:  313


Pretraining the vision module

In [3]:
class Vision(nn.Module):
    """LeNet-style convolutional encoder producing a 500-d feature vector.

    Expects 1-channel 28x28 images: two (conv 5x5 -> ReLU -> 2x2 max-pool)
    stages shrink them to 50 maps of 4x4, which a ReLU-activated linear layer
    maps to 500 features. Because of the final ReLU the features are >= 0.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2, stride=2)
        x = x.view(-1, 4 * 4 * 50)  # flatten the 50 feature maps of size 4x4
        return F.relu(self.fc1(x))

class PretrainNet(nn.Module):
    """Digit classifier used only to pretrain a vision module.

    Adds a 10-way linear head on top of the 500-d features of
    ``vision_module``; training this network trains the shared encoder.
    """

    def __init__(self, vision_module):
        super().__init__()
        self.vision_module = vision_module
        self.fc = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        features = self.vision_module(x)
        # Raw logits (no softmax); the training loop applies cross-entropy.
        return self.fc(F.leaky_relu(features))

def pretrain(num_epochs):
    """Pretrain a fresh Vision encoder on MNIST digit classification.

    A new ``Vision`` module is wrapped in ``PretrainNet`` (a 10-way classifier)
    and trained with cross-entropy on ``train_loader``. Only the encoder is
    returned; the classification head is discarded.

    Args:
        num_epochs: number of full passes over ``train_loader``.

    Returns:
        The trained ``Vision`` module (already moved to ``device``).
    """
    vision = Vision()
    class_prediction = PretrainNet(vision)  # note that we pass vision - which we want to pretrain
    # Move the model to its device *before* building the optimizer, as the
    # torch.optim documentation recommends.
    class_prediction = class_prediction.to(device)
    optimizer = core.build_optimizer(class_prediction.parameters())  # uses the options passed to core.init
    class_prediction.train()

    for epoch in range(num_epochs):
        total_loss, n_batches = 0.0, 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = class_prediction(data)
            batch_loss = F.cross_entropy(output, target)  # already a scalar (mean over batch)
            batch_loss.backward()
            optimizer.step()

            total_loss += batch_loss.item()
            n_batches += 1

        # max(..., 1) guards against an empty loader instead of dividing by zero.
        mean_loss = total_loss / max(n_batches, 1)
        print(f'Train Epoch: {epoch}, mean loss: {mean_loss}')

    return vision


In [4]:
class Sender(nn.Module):
    """Speaker agent: maps an image to ``output_size`` scores over the vocabulary.

    The ``vision`` encoder runs under ``torch.no_grad()``, so it is effectively
    frozen: gradients only reach the linear head ``fc`` (which expects the
    encoder's 500-d features).
    """

    def __init__(self, vision, output_size):
        super().__init__()
        self.fc = nn.Linear(500, output_size)
        self.vision = vision

    def forward(self, x, aux_input=None):
        # aux_input is accepted to match the calling convention; it is unused.
        with torch.no_grad():
            features = self.vision(x)
        return self.fc(features)


class Receiver(nn.Module):
    """Listener agent: decodes an ``input_size`` vector into 784 pixel values.

    The sigmoid keeps every output in (0, 1), as binary cross-entropy requires.
    """

    def __init__(self, input_size):
        super().__init__()
        self.fc = nn.Linear(input_size, 784)  # 784 = 28 * 28 pixels

    def forward(self, channel_input, receiver_input=None, aux_input=None):
        # receiver_input / aux_input are part of the calling convention but unused.
        logits = self.fc(channel_input)
        return logits.sigmoid()


def loss(sender_input, _message, _receiver_input, receiver_output, _labels, _aux_input=None):
    """Per-sample reconstruction loss for the autoencoding game.

    Binary cross-entropy between the receiver's predicted pixel probabilities
    and the sender's image flattened to 784 values, averaged over pixels.

    Returns:
        ``(loss, aux)``: ``loss`` holds one value per sample; ``aux`` is an
        empty dict (no extra metrics are logged).
    """
    flat_target = sender_input.view(-1, 784)
    per_pixel = F.binary_cross_entropy(receiver_output, flat_target, reduction='none')
    return per_pixel.mean(dim=1), {}

Evaluation on the test set

In [43]:
def plot_receiver_support(game, vocab_size):
    """Show what the receiver decodes from each individual symbol.

    For every symbol ``z`` in ``range(vocab_size)``, a one-hot float vector is
    fed to ``game.receiver`` and the output is drawn as a 28x28 grayscale image.

    The game is put in eval mode while plotting, and its previous train/eval
    mode is restored afterwards (even if plotting fails). This avoids leaving
    the game stuck in eval mode for later use.
    """
    was_training = game.training
    game.eval()
    try:
        # Row z of the identity matrix is the one-hot encoding of symbol z.
        symbols = torch.eye(vocab_size, device=device)
        for z in range(vocab_size):
            with torch.no_grad():
                # Receiver outputs a single tensor of predictions
                sample = game.receiver(symbols[z]).float().cpu()
            sample = sample.view(28, 28)
            plt.title(f"Input: symbol {z}")
            plt.imshow(sample, cmap='gray')
            plt.show()
    finally:
        game.train(was_training)

def size_dataset(dataset):
    """Return the total number of samples in a list of ``[inputs, labels]`` batches.

    Each batch's first element is its input tensor (as built by
    ``get_random_test_dataset`` / ``get_category_test_dataset``). The sizes of
    all batches are summed, so a multi-batch dataset is counted in full rather
    than by its first batch only. An empty dataset has size 0.
    """
    return sum(len(batch[0]) for batch in dataset)

def get_random_test_dataset(n, dataset=None):
    """
    Build a one-batch test set from ``n`` randomly chosen images.

    Indices are drawn without replacement using Python's ``random`` module, so
    the selection follows the global ``random`` seed.

    Args:
        n: Number of images to select; must satisfy ``1 <= n <= len(dataset)``.
        dataset: Indexable dataset of ``(image, label)`` pairs to sample from.
            Defaults to ``test_loader.dataset``.

    Returns:
        test_dataset: ``[[images, None]]`` -- a list holding a single batch whose
        inputs are the selected images stacked along a new leading dimension,
        with None as a placeholder for labels.

    Raises:
        ValueError: if ``n`` is smaller than 1 or larger than the dataset.
    """
    if dataset is None:
        dataset = test_loader.dataset

    if n < 1:
        raise ValueError(f"Requested {n} images, but at least 1 is required.")
    # Ensure there are enough samples in the dataset
    if n > len(dataset):
        raise ValueError(f"Requested {n} images, but the dataset only contains {len(dataset)} samples.")

    indices = random.sample(range(len(dataset)), n)
    # Keep only the images (labels are dropped); torch.stack adds the batch dim.
    test_inputs = torch.stack([dataset[idx][0] for idx in indices])

    return [[test_inputs, None]]


def get_category_test_dataset(dataset=None, num_classes=10, search_limit=100, occurrence=1):
    """Build a one-batch test set with one image per class, ordered by class.

    For each class ``c`` in ``range(num_classes)``, the image picked is the
    ``occurrence``-th (0-based; default 1, i.e. the second) sample of class
    ``c`` among the first ``search_limit`` entries of ``dataset.targets``.
    The defaults reproduce the original fixed selection from the test split.

    Args:
        dataset: Dataset exposing ``targets`` and returning ``(image, label)``
            on indexing. Defaults to ``test_loader.dataset``.
        num_classes: number of classes to collect (batch size of the result).
        search_limit: only the first ``search_limit`` samples are searched.
        occurrence: which occurrence (0-based) of each class to take.

    Returns:
        ``[[images, None]]``: a single batch; row ``c`` of ``images`` is the
        chosen image of class ``c``, and None is a placeholder for labels.

    Raises:
        ValueError: if some class does not occur more than ``occurrence``
            times within the searched range.
    """
    if dataset is None:
        dataset = test_loader.dataset

    targets = torch.as_tensor(dataset.targets[:search_limit])
    test_inputs = []
    for digit in range(num_classes):
        positions = (targets == digit).nonzero(as_tuple=True)[0]
        if positions.numel() <= occurrence:
            raise ValueError(
                f"Class {digit} occurs {positions.numel()} time(s) in the first "
                f"{search_limit} samples; need more than {occurrence}."
            )
        img, _ = dataset[int(positions[occurrence])]
        test_inputs.append(img.unsqueeze(0))

    return [[torch.cat(test_inputs), None]]


def test_loss(game, test_dataset, is_gs, variable_length):
    """Mean reconstruction loss of ``game`` on ``test_dataset``.

    Matches the training loss: binary cross-entropy of the receiver's predicted
    pixel probabilities (input) against the flattened sender image (target).
    It is averaged over pixels within each sample, then over all dumped samples.
    Previously the two arguments were swapped, which scored the ground-truth
    image as if it were the prediction.

    Returns:
        The mean loss as a Python float.
    """
    interaction = core.dump_interactions(game, test_dataset, is_gs, variable_length)
    n_samples = interaction.sender_input.size(0)
    targets = interaction.sender_input.reshape(n_samples, -1)
    predictions = interaction.receiver_output.reshape(n_samples, -1)
    per_sample = F.binary_cross_entropy(predictions, targets, reduction='none').mean(dim=1)
    return per_sample.mean().item()

def plot(game, test_dataset, is_gs, variable_length):
    """Show each test image side by side with the receiver's reconstruction.

    The interactions are dumped once via ``core.dump_interactions``. For each
    sample, the original image (left) and the 28x28 reconstruction (right) are
    drawn together, titled with the sample index and the channel message.
    """
    dumped = core.dump_interactions(game, test_dataset, is_gs, variable_length)
    n_samples = size_dataset(test_dataset)

    for idx in range(n_samples):
        original = dumped.sender_input[idx].squeeze(0)
        reconstruction = dumped.receiver_output[idx].view(28, 28)
        side_by_side = torch.cat((original, reconstruction), dim=1).cpu().numpy()

        plt.title(f"Input: digit {idx}, channel message {dumped.message[idx]}")
        plt.imshow(side_by_side, cmap='gray')
        plt.show()

def plot_test_performance(game, is_gs, variable_length):
    """Plot original vs. reconstruction for one test image of each digit class."""
    plot(game, get_category_test_dataset(), is_gs, variable_length)



In [6]:
import inspect

# Debug check: print which installed file provides EGG's GumbelSoftmaxWrapper
# (useful to confirm the forked EGG package is the one being imported).
print(inspect.getfile(core.GumbelSoftmaxWrapper))

/usr/local/lib/python3.11/dist-packages/egg/core/gs_wrappers.py


In [7]:
def test_optimality_encoder(game, test_set_size, vocab_size = 10):


def test_alignment(game, test_set_size, vocab_size = 10):

IndentationError: expected an indented block after function definition on line 1 (<ipython-input-7-cb836481d7e1>, line 4)

In [19]:
def gs_pipeline(num_pretrain_epochs=10,
                vocab_size=10,
                n_epochs=15,
                sender_update_freq=1,
                sender_temp_decay=0.9,
                vision=None,
                receiver_input_size=400,
                temperature=1.0,
                min_temperature=0.1,
                ):
    """Build and train a Gumbel-Softmax autoencoding game on MNIST.

    Args:
        num_pretrain_epochs: epochs used to pretrain a new Vision encoder;
            ignored when ``vision`` is given.
        vocab_size: number of symbols the sender can emit.
        n_epochs: number of training epochs for the game.
        sender_update_freq: forwarded to ``core.SymbolGameGS``.
        sender_temp_decay: ``decay`` passed to ``core.TemperatureUpdater``.
        vision: optional pretrained Vision module to reuse across runs.
        receiver_input_size: size of the symbol embedding fed to the Receiver.
            It is used for both the Receiver and its wrapper, so they always agree.
        temperature: initial Gumbel-Softmax temperature of the sender wrapper.
        min_temperature: ``minimum`` passed to ``core.TemperatureUpdater``.

    Returns:
        The trained game. As a side effect, prints its mean test
        reconstruction loss on 100 random test images.
    """
    if vision is None:
        vision = pretrain(num_pretrain_epochs)
    sender = Sender(vision, vocab_size)
    # Wrap into a Gumbel-Softmax interface, which requires a GS temperature.
    sender = core.GumbelSoftmaxWrapper(sender, temperature=temperature)
    receiver = Receiver(input_size=receiver_input_size)
    receiver = core.SymbolReceiverWrapper(receiver, vocab_size, agent_input_size=receiver_input_size)

    game = core.SymbolGameGS(sender, receiver, loss, sender_update_freq=sender_update_freq)
    # Same optimizer factory as pretrain(), so '--optimizer' / '--lr' from core.init apply here too.
    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(
        game=game, optimizer=optimizer, train_data=train_loader,
        validation_data=test_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=True),
            core.TemperatureUpdater(agent=sender, decay=sender_temp_decay, minimum=min_temperature),
        ],
    )

    trainer.train(n_epochs)

    random_test_dataset = get_random_test_dataset(100)
    print(test_loss(game, random_test_dataset, is_gs=True, variable_length=False))

    return game

In [20]:
# Pretrain the encoder once; every game trained below reuses this module.
vision = pretrain(num_epochs=15)

Train Epoch: 0, mean loss: 0.12254741532768433
Train Epoch: 1, mean loss: 0.04059718344427723
Train Epoch: 2, mean loss: 0.02874404519118058
Train Epoch: 3, mean loss: 0.020632181487870913
Train Epoch: 4, mean loss: 0.014911599768842765
Train Epoch: 5, mean loss: 0.01314719554741091
Train Epoch: 6, mean loss: 0.011584931882929618
Train Epoch: 7, mean loss: 0.00946955801208773
Train Epoch: 8, mean loss: 0.007430406021851075
Train Epoch: 9, mean loss: 0.007094549284744077
Train Epoch: 10, mean loss: 0.007755614051332237
Train Epoch: 11, mean loss: 0.006623633634604119
Train Epoch: 12, mean loss: 0.004826118141988124
Train Epoch: 13, mean loss: 0.00609306494914883
Train Epoch: 14, mean loss: 0.00459140554880463


In [None]:
# Train one game per sender update frequency. Both games share the same
# pretrained `vision` module, which Sender runs under torch.no_grad().
# NOTE: 2000 epochs over the full MNIST training split is very long-running.
# `games[i]` corresponds to `list_sender_update_freq[i]`.
list_sender_update_freq = [1, 1000]
games = []
for d in list_sender_update_freq:
    game = gs_pipeline(vision=vision,
                       sender_update_freq=d,
                       n_epochs=2000,
                       vocab_size=10)
    games.append(game)  # appended per run, so an interrupted loop keeps finished games

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
{"loss": 0.21203476190567017, "mode": "test", "epoch": 118}
{"loss": 0.21228298544883728, "mode": "train", "epoch": 119}
{"loss": 0.21196091175079346, "mode": "test", "epoch": 119}
{"loss": 0.21244627237319946, "mode": "train", "epoch": 120}
{"loss": 0.21207328140735626, "mode": "test", "epoch": 120}
{"loss": 0.21244430541992188, "mode": "train", "epoch": 121}
{"loss": 0.21218645572662354, "mode": "test", "epoch": 121}
{"loss": 0.2124057561159134, "mode": "train", "epoch": 122}
{"loss": 0.21205955743789673, "mode": "test", "epoch": 122}
{"loss": 0.21239729225635529, "mode": "train", "epoch": 123}
{"loss": 0.21194370090961456, "mode": "test", "epoch": 123}
{"loss": 0.2124321460723877, "mode": "train", "epoch": 124}
{"loss": 0.2122977077960968, "mode": "test", "epoch": 124}
{"loss": 0.21253129839897156, "mode": "train", "epoch": 125}
{"loss": 0.21214431524276733, "mode": "test", "epoch": 125}
{"loss": 0.2123943269252777, "m

In [None]:
# Receiver's decoding of each symbol for the game trained with sender_update_freq=1.
plot_receiver_support(game=games[0], vocab_size=10)

In [None]:
# Receiver's decoding of each symbol for the game trained with sender_update_freq=1000.
plot_receiver_support(game=games[1], vocab_size=10)

In [None]:
# One test image per digit vs. its reconstruction (game with sender_update_freq=1).
test_category_dataset = get_category_test_dataset()
plot(game=games[0], test_dataset=test_category_dataset, is_gs=True, variable_length=False)

In [None]:
# One test image per digit vs. its reconstruction (game with sender_update_freq=1000).
test_category_dataset = get_category_test_dataset()
plot(game=games[1], test_dataset=test_category_dataset, is_gs=True, variable_length=False)