# Disjoint-domain network

### Ethan Blackwood
### October 23, 2020

**Goal**: Train and analyze the network from Rogers & McClelland (2008) with 4 disjoint domains (Figures R3-R5). The network learns to extract the feature of being more or less similar to other items in the same domain, and it does so across all 4 domains even though they have no items, contexts, or attributes in common.

In [31]:
%matplotlib widget
%config IPCompleter.greedy=True

import itertools
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from scipy.cluster import hierarchy
from sklearn.manifold import MDS, TSNE
from sklearn.decomposition import PCA

import disjoint_domain

Get inputs and outputs with a particular similarity structure.

In [2]:
# Shared NumPy random generator. No seed is given, so results differ from run to run.
rng = np.random.default_rng()

In [37]:
# can afford to use doubles for this
# (torch.set_default_dtype replaces the deprecated torch.set_default_tensor_type)
torch.set_default_dtype(torch.float64)

n_domains = 4
n_contexts = 4
attrs_per_context = 50

# One single-letter name per domain: 'A', 'B', ...
domains = [chr(ord('A') + d) for d in range(n_domains)]
item_mat, context_mat, attr_mat = disjoint_domain.make_io_mats(
    n_contexts=n_contexts, attrs_per_context=attrs_per_context, n_domains=n_domains, rng=rng)

# Training set as double-precision tensors: item inputs, context inputs, attribute targets
x_item = torch.tensor(item_mat, dtype=torch.double)
x_context = torch.tensor(context_mat, dtype=torch.double)
y = torch.tensor(attr_mat, dtype=torch.double)

# Make some variables for individual inputs for convenience later
# (identity matrix: row k is the one-hot vector for item k)
items = np.eye(disjoint_domain.N_ITEMS * n_domains, dtype=np.float64)
def itemgroup(n):
    """Return the bracket pair that tags the group of item index ``n``.

    Indices below 4 map to '()', indices 4 and 5 map to '[]', and every
    index from 6 upward maps to '{}'.
    """
    # (exclusive upper bound, tag) pairs, checked in ascending order
    for upper_bound, brackets in ((4, '()'), (6, '[]')):
        if n < upper_bound:
            return brackets
    return '{}'

# Item labels: domain letter, 1-based item number, then the group brackets (e.g. 'A1()')
item_names = [
    f'{dom}{idx + 1}{itemgroup(idx)}'
    for dom, idx in itertools.product(domains, range(disjoint_domain.N_ITEMS))
]

# One one-hot row per context, pooled over all domains
contexts = np.identity(n_contexts * n_domains, dtype=np.float64)
# Context labels: domain letter followed by the 1-based context number (e.g. 'A1')
context_names = [
    f'{dom}{idx + 1}'
    for dom, idx in itertools.product(domains, range(n_contexts))
]

Now build the network and training function.

In [16]:
class DisjointDomainNet(nn.Module):
    """Feed-forward network mapping an (item, context) input pair to attribute activations.

    The item and context inputs each pass through their own sigmoid
    representation layer. Both representations feed one shared sigmoid hidden
    layer, which drives a sigmoid attribute output layer.
    """

    def __init__(self, n_contexts, attrs_per_context, n_domains):
        """Create the layers and give every parameter a small random start.

        Args:
            n_contexts: number of contexts per domain.
            attrs_per_context: number of attributes per context.
            n_domains: number of domains.
        """
        super().__init__()

        # Input/output sizes, pooled over all domains
        self.n_items = disjoint_domain.N_ITEMS * n_domains
        self.n_contexts = n_contexts * n_domains
        self.n_attributes = attrs_per_context * self.n_contexts

        # Layer widths -- not sure if these are reasonable
        item_rep_size, ctx_rep_size = self.n_items, self.n_contexts
        hidden_size = item_rep_size * 2

        # define layers
        self.item_to_irep = nn.Linear(self.n_items, item_rep_size)
        self.ctx_to_crep = nn.Linear(self.n_contexts, ctx_rep_size)
        self.irep_to_hidden = nn.Linear(item_rep_size, hidden_size)
        # only need 1 hidden layer bias, so the context pathway has none
        self.crep_to_hidden = nn.Linear(ctx_rep_size, hidden_size, bias=False)
        self.hidden_to_attr = nn.Linear(hidden_size, self.n_attributes)

        # make all parameters (weights and biases) start small
        with torch.no_grad():
            for param in self.parameters():
                nn.init.normal_(param.data, std=0.01)

    def forward(self, item, context):
        """Return the attribute-layer activations for the given item and context inputs."""
        item_rep = torch.sigmoid(self.item_to_irep(item))
        context_rep = torch.sigmoid(self.ctx_to_crep(context))
        # The two pathways are summed before the hidden nonlinearity
        hidden_input = self.irep_to_hidden(item_rep) + self.crep_to_hidden(context_rep)
        hidden_act = torch.sigmoid(hidden_input)
        return torch.sigmoid(self.hidden_to_attr(hidden_act))

In [24]:
def train_network(net, optimizer, num_epochs=200, snap_freq=20, batch_size=4, scheduler=None):
    """Train ``net`` with mini-batch gradient steps and record item-representation snapshots.

    Reads the module-level training tensors ``x_item``, ``x_context`` and ``y``,
    the probe matrix ``items`` (one row per item) and the random generator ``rng``.

    Args:
        net: network to train. It must expose ``n_items`` and ``item_to_irep``
            and be callable as ``net(item_batch, context_batch)``.
        optimizer: optimizer over the parameters of ``net``.
        num_epochs: number of passes over the training set.
        snap_freq: a snapshot is taken (and progress printed) on every epoch
            whose index is a multiple of this value.
        batch_size: number of training samples per gradient step.
        scheduler: optional learning-rate scheduler, stepped once per epoch.

    Returns:
        Array of shape ``(n_snaps, net.n_items, n_rep)`` holding the sigmoid
        activations of the item representation layer for every row of ``items``.
        Each snapshot is taken at the *start* of its epoch, before that epoch's
        updates, and ``n_snaps == ceil(num_epochs / snap_freq)``.
    """
    # Snapshots happen at epochs 0, snap_freq, 2*snap_freq, ... (< num_epochs), so the
    # count is the ceiling of num_epochs / snap_freq. Floor division would leave no
    # slot for the last snapshot whenever num_epochs is not a multiple of snap_freq.
    n_snaps = -(-num_epochs // snap_freq)
    n_rep = net.item_to_irep.out_features

    # Holds snapshots of input representation layer after probing with each item
    rep_snapshots = np.empty((n_snaps, net.n_items, n_rep))

    criterion = nn.MSELoss()

    n_samples = len(y)
    n_batches = (n_samples - 1) // batch_size + 1  # ceil(n_samples / batch_size)

    # All item probes as one batch; converted once since it does not change during training
    probes = torch.tensor(items)

    for epoch in range(num_epochs):
        # collect snapshot
        if epoch % snap_freq == 0:
            with torch.no_grad():
                reps = torch.sigmoid(net.item_to_irep(probes))
            rep_snapshots[epoch // snap_freq] = reps.numpy()

        running_loss = 0.0
        running_accuracy = 0.0

        # Visit the training samples in a fresh random order every epoch
        order = rng.permutation(n_samples)
        for k_batch in range(n_batches):
            # train
            batch_inds = order[k_batch * batch_size:(k_batch + 1) * batch_size]
            targets = y[batch_inds]

            # Zero first so gradients left over from before this call cannot leak in
            optimizer.zero_grad()
            outputs = net(x_item[batch_inds], x_context[batch_inds])
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                # Weight by batch length because the final batch may be shorter
                running_loss += loss.item() * len(batch_inds)
                predictions = (outputs > 0.5).to(torch.double)
                accuracy = torch.mean((predictions == targets).to(torch.double))
                running_accuracy += accuracy.item() * len(batch_inds)

        if epoch % snap_freq == 0:
            print(f'Epoch {epoch} end: mean loss = {running_loss / n_samples:.3f}, mean accuracy = {running_accuracy / n_samples:.3f}')

        if scheduler is not None:
            scheduler.step()

    return rep_snapshots

Moment of truth, time to run it

In [25]:
# Fresh network trained with plain SGD. The StepLR schedule multiplies the
# learning rate by 0.95 every 1000 scheduler steps.
net = DisjointDomainNet(
    n_contexts=n_contexts,
    attrs_per_context=attrs_per_context,
    n_domains=n_domains,
)
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.95, step_size=1000)

# Snapshot the item representations (and print progress) every 1000 epochs
rep_snapshots = train_network(
    net,
    optimizer,
    num_epochs=30000,
    snap_freq=1000,
    batch_size=4,
    scheduler=scheduler,
)

Epoch 0 end: mean loss = 0.241, mean accuracy = 0.771
Epoch 1000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 2000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 3000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 4000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 5000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 6000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 7000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 8000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 9000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 10000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 11000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 12000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 13000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 14000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 15000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 16000 end: mean loss = 0.030, mean accuracy = 0.969
Epoch 17000 end: mean loss 

In [38]:
# Hierarchically cluster the item representations from the last recorded snapshot.
# method/metric are SciPy's defaults, spelled out here for the reader.
z = hierarchy.linkage(rep_snapshots[-1], method='single', metric='euclidean')

# Draw the resulting tree with one labelled leaf per item
plt.figure()
hierarchy.dendrogram(z, labels=item_names)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …