# Branching Image Processor with Aggregating MHA Encoders

This fully functioning notebook implements an architecture for processing an image dataset which contains multiple classification tasks.

Each batch of images is passed through an ImageToEntities module, which creates entities representing convolutional output features. The samples in the batch are then split by classification task and passed through a parallel set of MHA Encoders -- one per classification task in the dataset. This is similar to the image-handling approach in ["Relational Deep Reinforcement Learning"](https://arxiv.org/abs/1806.01830).

Each of these AggregatingMHAEncoders consists of N stacks of MHA/Normalization/Feed-Forward layers, based upon the encoder portion of the encoder/decoder architecture originally described in ["Attention is All You Need"](https://arxiv.org/abs/1706.03762). A final layer in each AggregatingMHAEncoder -- either a Max Pooling function or an "AggregatingMHA" function -- reduces the dimensionality of each encoder's output (with no aggregation selected, the entities are simply flattened instead).

The sub-batches output by the encoders are then recombined into a single batch for a final module. Each sample in the batch is concatenated with its associated classification task, and this recombined batch is then passed to a final feed-forward module.

For testing, this repository contains one dataset from ["An Explicitly Relational Neural Network Architecture"](https://arxiv.org/abs/1905.10307). 

Each of these architectural elements is shown in more detail in diagrams in the rest of the notebook.


<img src="images/branching_mha_encoder.png">

In [1]:
%matplotlib inline
import os
import sys
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter
from torch.multiprocessing import Pool, Process, set_start_method
import math
import pdb
import matplotlib.pylab as plt
import fnmatch
from  torch.nn.utils import clip_grad_value_
import copy
import argparse





In [2]:
class BranchingImageProcessor(nn.Module):
    """Route each sample through the encoder for its task, then classify.

    ``conv_net`` turns the image batch into entities, ``XSplitter`` splits the
    batch by task, the encoder at the same index as each task's splitter
    processes that sub-batch, and ``recombined_ff`` classifies the
    re-concatenated encoder outputs together with their task ids.
    """

    def __init__(self, conv_net,encoder_wrappers,recombined_ff,tasks_in_data,d_model):
        super(BranchingImageProcessor, self).__init__()
        self.conv_net=conv_net
        # Read by NoamOpt.get_std_opt to scale the learning rate.
        self.d_model=d_model

        # Encoder i serves the task at position i of tasks_in_data.
        self.task_encoders = nn.ModuleList(encoder_wrappers[0:])
        self.recombined_ff=recombined_ff
        # create_x_splitters registers the splitters on the XSplitter class;
        # keep a handle on that list (the factory's return value is not relied on).
        XSplitter.create_x_splitters(tasks_in_data)
        self.x_splitter_list=XSplitter.x_splitter_list

    def forward(self, x,task_tensor,y_tensor):
        """Encode each task's sub-batch with its own encoder and classify.

        Args:
            x: image batch, passed to ``conv_net``.
            task_tensor: per-sample task ids used to split the batch.
            y_tensor: per-sample labels.

        Returns:
            ``(logits, y)``. Both are re-ordered so that samples are grouped
            into contiguous blocks by task, in splitter order. Use the returned
            ``y`` (not the input one) when computing the loss.

        Raises:
            ValueError: if no sample's task matches a registered splitter.
        """
        conv_input=self.conv_net(x)

        split_list,split_has_data=XSplitter.split_inputs_by_task(conv_input,task_tensor,y_tensor)
        # Only run encoders whose task actually appears in this batch; the
        # encoder index matches the splitter index.
        encoded=[self.task_encoders[index](x_split)
                 for index,(x_split,has_data) in enumerate(zip(split_list,split_has_data))
                 if has_data]
        if not encoded:
            raise ValueError("no sample in the batch belongs to a task in tasks_in_data")
        # One concatenation instead of growing the tensor pairwise.
        final_input=torch.cat(encoded,0)
        new_task,new_y_tensor=XSplitter.reconstruct_task_and_y_tensors()

        final_output = self.recombined_ff(final_input,new_task)
        return final_output,new_y_tensor

    @classmethod
    def make_model(cls,args):
        """Build a model from parsed command-line ``args``.

        One AggregatingMHAEncoder is built per ``args.encoder_count``. Each one
        gets independent deep copies of the attention and feed-forward
        templates. Every parameter with more than one dimension is then
        re-initialised with Xavier-uniform.
        """
        c = copy.deepcopy

        # Honour --attention-dropout, which was parsed but never used. Fall
        # back to MultiHeadedAttention's own default (0.1) if args lacks it.
        attn=MultiHeadedAttention(args.encoder_attention_heads,args.d_model,
                                  getattr(args,'attention_dropout',0.1))
        ffn=EncoderFeedForward(args.d_model, args.encoder_ffn_dim, args.dropout)
        encoder_wrappers=[]

        # NOTE: forward indexes encoders by splitter position, so encoder_count
        # must cover every task that can appear in a batch.
        for _ in range(args.encoder_count):
            encoder = Encoder(EncoderLayer(args.d_model,c(attn),
                                            c(ffn),args.dropout),args.encoder_layers)
            encoder_wrappers.append(AggregatingMHAEncoder(encoder,args.d_model,args.entity_count,
                                 args.aggregate_attention_heads,args.dropout,args.aggregation_method))

        recombined_ff= RecombinedFF(args.d_model,args.final_module_hidden_size,args.tgt_class_count,args.entity_count,args.aggregation_method)

        model = cls(
            ImageToEntities(args),
            encoder_wrappers,
            recombined_ff,
            args.tasks_in_data,
            args.d_model)

        # Xavier-uniform for weight matrices and conv kernels only. 1-D
        # parameters (biases, LayerNorm gain/offset) keep their constructor init.
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        return model
 



In [3]:
class RecombinedFF(nn.Module):
    """Final classifier applied to the recombined encoder outputs.

    Each sample's encoder output is concatenated with its task id (one extra
    feature) and passed through two linear layers.

    NOTE(review): there is no non-linearity between proj1 and proj2, so the
    pair collapses to a single linear map. It is left unchanged so that
    existing checkpoints keep their behaviour.
    """
    def __init__(self, d_model,final_module_hidden_size,tgt_class_count,entity_count,aggregation_method):
        super(RecombinedFF, self).__init__()

        if aggregation_method==AGG_METHOD_NONE:
            # Encoder output is flattened: entity_count vectors of d_model each.
            final_module_input_size=(d_model*entity_count)+1
        else:
            # Encoder output is aggregated down to a single d_model vector.
            final_module_input_size=d_model+1

        self.proj1 = nn.Linear(final_module_input_size, final_module_hidden_size)
        self.proj2 = nn.Linear(final_module_hidden_size,tgt_class_count)

    def forward(self,x,task):
        """Return class logits.

        Args:
            x: (batch, final_module_input_size - 1) encoder outputs.
            task: (batch,) task ids; appended to ``x`` as one extra feature.
        """
        # Cast the task ids to x's dtype. Otherwise torch.cat may promote the
        # input (e.g. a float64 task tensor makes it float64, which the
        # float32 proj1 rejects), and older PyTorch refuses mixed dtypes.
        task=task.unsqueeze(1).to(x.dtype)
        x=torch.cat((x,task),1)
        x1=self.proj1(x)
        x2=self.proj2(x1)
        return x2


    

# XSplitter

Helper class for splitting a batch's X, Y, and Task values by task, and later recombining them.
See its use in the BranchingImageProcessor module above.

In [4]:
class XSplitter:
    """Split a batch by task, and later rebuild matching task/label tensors.

    Don't instantiate directly. ``create_x_splitters`` builds one splitter per
    task and stores them, in task order, in the class-level
    ``x_splitter_list`` that the other class methods operate on. That list and
    each splitter's cached splits are shared class state.
    """

    x_splitter_list = []

    def __init__(self,select_task):
        self.select_task=select_task

    def select(self,x,task_tensor,y_tensor):
        """Return ``(x_rows_for_this_task, has_any_rows)``.

        Also caches the matching task and label rows for
        ``get_task_and_y_splits``.
        """
        # Boolean mask over the batch dimension, used directly as the index.
        mask=task_tensor==self.select_task
        split=x[mask]
        self.task_split=task_tensor[mask]
        self.y_split=y_tensor[mask]
        return (split,len(split)>0)

    def get_task_and_y_splits(self):
        """Return the task and label rows cached by the last ``select`` call."""
        return (self.task_split,self.y_split)

    @classmethod
    def create_x_splitters(cls,task_list):
        """(Re)build the class-level splitter list, one splitter per task.

        The list is replaced rather than appended to. Building a second model
        therefore does not duplicate splitters; duplicates would make
        ``split_inputs_by_task`` return more splits than there are encoders.

        Returns:
            The new ``x_splitter_list``.
        """
        cls.x_splitter_list=[cls(int(task)) for task in task_list]
        return cls.x_splitter_list

    @classmethod
    def split_inputs_by_task(cls,x,task_tensor,y_tensor):
        """Split a batch with every registered splitter.

        Returns:
            ``(split_list, split_has_data)``: parallel lists, in splitter
            order, holding each task's x rows and whether it had any.
        """
        split_list=[]
        split_has_data=[]
        for x_splitter in cls.x_splitter_list:
            split,has_data=x_splitter.select(x,task_tensor,y_tensor)
            split_list.append(split)
            split_has_data.append(has_data)
        return (split_list,split_has_data)

    @classmethod
    def reconstruct_task_and_y_tensors(cls):
        """Concatenate the cached task/label splits in splitter order.

        After the parallel encoders run, their outputs are grouped into
        contiguous blocks by task. This returns task and label tensors in the
        same order, so they line up with those outputs.

        Raises:
            ValueError: if no splitter matched any sample.
        """
        pairs=[x_splitter.get_task_and_y_splits() for x_splitter in cls.x_splitter_list]
        pairs=[(task_split,y_split) for task_split,y_split in pairs if len(task_split)>0]
        if not pairs:
            raise ValueError("no samples matched any registered task")
        task_parts,y_parts=zip(*pairs)
        return (torch.cat(task_parts,0),torch.cat(y_parts,0))


The following module has the same structure as the encoder half of the Transformer architecture, as initially presented in ["Attention is All You Need"](https://arxiv.org/abs/1706.03762).
The difference is that the second dimension of the final output is reduced from args.entity_count to one, either via a max function or via an AggregatingMultiHeadedAttention function. With AGG_METHOD_NONE, the entities are instead flattened into a single vector per sample.

**Aggregating Multi-Headed Encoder**

<img src="images/encoder_wrapper.png">

In [5]:
class AggregatingMHAEncoder(nn.Module):
    """Per-task encoder stack followed by an entity-aggregation step.

    ``aggregation_method`` selects how the encoder's per-entity output is
    reduced:
      * AGG_METHOD_NONE: flatten all entities into one vector per sample
        (returned without ReLU);
      * AGG_METHOD_MAX: element-wise max over the entity dimension;
      * anything else: AggregatingMultiHeadedAttention.
    The last two are followed by a ReLU.
    """

    def __init__(self, encoder, d_model,entity_count,heads,dropout,aggregation_method):
        super(AggregatingMHAEncoder, self).__init__()

        self.encoder=encoder
        # NOTE: constructed for every aggregation method, even when forward never uses it.
        self.aggregating_multi_head_attn=AggregatingMultiHeadedAttention(heads,d_model,entity_count,dropout)
        self.aggregation_method=aggregation_method

    def forward(self,x):
        """Encode ``x`` and reduce it to one vector per sample.

        ``x`` is assumed to be (batch, entities, d_model) — TODO confirm.
        """
        x=self.encoder(x)
        if self.aggregation_method==AGG_METHOD_NONE:
            x=x.view(x.shape[0],-1)
            return x
        if self.aggregation_method==AGG_METHOD_MAX:
            # Reduce over the entity dimension only. A bare squeeze() would
            # also drop the batch dimension for a single-sample sub-batch,
            # breaking the later torch.cat over sub-batches.
            x=torch.max(x,1)[0]
        else:
            x=self.aggregating_multi_head_attn(x,x,x)
        return F.relu(x)



# Encoder

The code comprising the Encoder layer was created by modifying code from the elegant PyTorch implementation in ["The Annotated Transformer"](http://nlp.seas.harvard.edu/2018/04/03/attention.html).

In [6]:
class Encoder(nn.Module):
    """Stack of N identical encoder layers followed by a final LayerNorm."""
    def __init__(self, layer,N):
        super().__init__()
        # N independent deep copies of the template layer.
        self.layers = clones(layer, N)
        # The final norm runs over the layer's model width.
        self.norm = LayerNorm(layer.size)

    def forward(self, x):
        """Feed ``x`` through every layer in order, then normalise the result."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
        return self.norm(hidden)
    

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias.

    Uses ``Tensor.std`` (unbiased by default) and adds ``eps`` to the standard
    deviation rather than to the variance.
    """
    def __init__(self, features, eps=1e-6):
        super().__init__()
        # Gain starts at 1 and bias at 0; registration order a_2 then b_2.
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last axis, then scale by a_2 and shift by b_2."""
        centred = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centred / spread + self.b_2



class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    The norm is applied to the sublayer's input rather than after the
    residual addition, which keeps the code simple.
    """
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Run ``sublayer`` on the normalised input and add the residual ``x``."""
        branch = self.dropout(sublayer(self.norm(x)))
        return x + branch



class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then feed-forward.

    Each step is wrapped in its own residual SublayerConnection.
    """
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # Two independent residual wrappers: [0] for attention, [1] for the FFN.
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        # Model width; Encoder reads it to size its final LayerNorm.
        self.size = size

    def forward(self, x):
        """Apply self-attention (query = key = value = x), then the feed-forward net."""
        attn_wrapper, ff_wrapper = self.sublayer
        attended = attn_wrapper(x, lambda inp: self.self_attn(inp, inp, inp))
        return ff_wrapper(attended, self.feed_forward)




def attention(query, key, value, dropout=None):
    """Scaled dot-product attention.

    Returns ``(softmax(Q K^T / sqrt(d_k)) V, weights)``, where ``d_k`` is the
    size of the query's last dimension. ``dropout``, if given, is applied to
    the weights before they are used, and those dropped-out weights are what
    is returned.
    """
    scale = math.sqrt(query.size(-1))
    logits = torch.matmul(query, key.transpose(-2, -1)) / scale
    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value), weights



class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    ``d_model`` is split into ``h`` heads of width ``d_k = d_model // h``.
    ``linears[0:3]`` project query/key/value, and ``linears[3]`` projects the
    concatenated heads back to ``d_model``. The latest attention weights are
    kept in ``self.attn``.
    """
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        # Per-head width; values use the same width as keys.
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        # Applied to the attention weights inside ``attention``.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value):
        """Attend from ``query`` over ``key``/``value``.

        Inputs are assumed to be (batch, length, d_model). The output is
        (batch, query_length, d_model).
        """
        batch = query.size(0)
        q_proj, k_proj, v_proj, out_proj = self.linears

        def to_heads(projection, tensor):
            # (batch, len, d_model) -> (batch, h, len, d_k)
            return projection(tensor).view(batch, -1, self.h, self.d_k).transpose(1, 2)

        heads, self.attn = attention(to_heads(q_proj, query),
                                     to_heads(k_proj, key),
                                     to_heads(v_proj, value),
                                     dropout=self.dropout)

        # (batch, h, len, d_k) -> (batch, len, h * d_k), then the output projection.
        merged = heads.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return out_proj(merged)



class EncoderFeedForward(nn.Module):
    """Position-wise feed-forward net computing ``w_2(dropout(relu(w_1(x))))``."""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Expand from d_model to d_ff, then project back to d_model.
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the two-layer MLP over the last dimension of ``x``."""
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)


    

class NoamOpt:
    """Optimizer wrapper implementing the Noam learning-rate schedule.

    lr(step) = factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)

    That is, a linear warm-up for ``warmup`` steps followed by
    inverse-square-root decay. The new rate is written into every param group
    before each optimizer step.
    """
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Advance the schedule, set the new rate on the optimizer, and step it."
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step = None):
        """Return the scheduled learning rate for ``step`` (default: the current step).

        A step <= 0 (e.g. before any update) returns 0.0, the start of the
        warm-up ramp. Without this guard, ``0 ** -0.5`` raises
        ZeroDivisionError and a negative step yields a complex number.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))

    @classmethod
    def get_std_opt(cls, model, factor=2, warmup=4000):
        """Wrap ``Adam(betas=(0.9, 0.98), eps=1e-9)`` in a Noam schedule.

        The schedule is sized by ``model.d_model``. ``factor`` and ``warmup``
        default to the previously hard-coded 2 and 4000.
        """
        return cls(model.d_model, factor, warmup,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

def clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    return nn.ModuleList(copy.deepcopy(module) for _ in range(N))


In ["Relational Deep Reinforcement Learning"](https://arxiv.org/abs/1806.01830), which uses a similar image transformation module to produce entities for an MHA function, the MHA layers are followed by a max pooling function that reduces the dimensionality to a single entity.
In my model, as the final step of the AggregatingMHAEncoder, I offer max pooling when `--aggregation-method` is AGG_METHOD_MAX. Otherwise (AGG_METHOD_MHA) I use the AggregatingMultiHeadedAttention function below; with AGG_METHOD_NONE, the entities are simply flattened.
This function is identical to the MHA layer in the encoder module, with two differences. Here the entities are flattened to form the initial query, so the final output is aggregated to a single entity per sample. Also, no final output projection is applied.


<img src="images/top.png"  align="center"/>
<img src="images/attention.png"  width="200" height="80" align="center"/>
<img src="images/aggregated_attention.png"  align="center"/>

In [7]:
class AggregatingMultiHeadedAttention(nn.Module):
    """Multi-head attention that pools a set of entities into one vector.

    The query is the whole entity set, flattened per sample and projected to
    ``d_model``, so each head issues a single query over all entities. Unlike
    MultiHeadedAttention there is no output projection: the concatenated
    heads are returned directly as a (batch, d_model) tensor.
    """
    def __init__(self, h, d_model, entity_count,dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        # Per-head width; values use the same width as keys.
        self.d_k = d_model // h
        self.h = h
        # Key and value projections, in that order.
        self.key_value_linears = clones(nn.Linear(d_model, d_model), 2)
        # Maps the flattened (entity_count * d_model) entity set to one query.
        self.query_linear = nn.Linear(d_model*entity_count,d_model)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value):
        """Return a (batch, d_model) summary of ``value``, attended by the pooled query.

        ``query`` must hold exactly entity_count * d_model elements per sample.
        """
        batch = query.size(0)
        key_proj, value_proj = self.key_value_linears
        k = key_proj(key).view(batch, -1, self.h, self.d_k).transpose(1, 2)
        v = value_proj(value).view(batch, -1, self.h, self.d_k).transpose(1, 2)

        # Flatten each sample's entities into one row, project it to a single
        # query, and split that query across the heads: (batch, h, 1, d_k).
        pooled_query = self.query_linear(query.view(batch, -1))
        q = pooled_query.view(batch, -1, self.h, self.d_k).transpose(1, 2)

        heads, self.attn = attention(q, k, v, dropout=self.dropout)
        return heads.transpose(1, 2).contiguous().view(batch, self.h * self.d_k)

# ImageToEntities Module

Module consists of:
 - Convolutional Layer
 - Layer which adds channels representing the x and y coordinates of the convolutional output features, as described in ["An Intriguing Failing of Convolutional Neural Networks"](https://arxiv.org/pdf/1807.03247.pdf),
 using an implementation borrowed from ["CoordConv-Pytorch"](https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py).
 - A reshape/permutation step which creates a set of entities for MultiHeadedAttention processing.  This transformation and depiction below are similar to ["Relational Deep Reinforcement Learning"](https://arxiv.org/abs/1806.01830).


<img src="images/image_to_entities.png">

In [8]:
class ImageToEntities(nn.Module):
    """Turn an image batch into a set of entities for the MHA encoders.

    Pipeline: Conv2d -> optional CoordConv coordinate channels -> flatten the
    spatial grid so each output location becomes one entity -> dropout ->
    ReLU. The output shape is (batch, out_h * out_w, args.out_channels).
    """
    def __init__(self,args,):
        super(ImageToEntities, self).__init__()
        self.x_dim=args.x_dim
        self.y_dim=args.y_dim
        self.out_channels=args.out_channels
        self.in_channels=args.in_channels
        self.add_output_coords=args.add_output_coords
        self.dropout = nn.Dropout(args.conv_dropout)
        # Honour --kernel-size / --stride; 12 / 6 match their argparse defaults.
        kernel_size=getattr(args,'kernel_size',12)
        stride=getattr(args,'stride',6)
        with_r=getattr(args,'with_r',False)

        if self.add_output_coords:
            # Reserve room for the coordinate channels, so the entity width
            # still equals args.out_channels: 2 for (x, y), plus 1 for r.
            self.out_channels-=3 if with_r else 2
        self.conv = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=kernel_size, stride=stride, padding=0)

        if self.add_output_coords:
            # Conv output grid size (padding=0, dilation=1), derived rather than
            # hard-coded; it is 5 x 5 for the defaults. Assumes images are laid
            # out (batch, c, x_dim, y_dim), as AddCoordsTh expects — TODO confirm.
            out_x=(self.x_dim-kernel_size)//stride+1
            out_y=(self.y_dim-kernel_size)//stride+1
            self.coord_output_adder = AddCoordsTh(x_dim=out_x, y_dim=out_y, with_r=with_r)

    def forward(self, x):
        """Return ReLU(dropout(entities)) for an image batch ``x``."""
        x = self.conv(x)
        if self.add_output_coords:
            x = self.coord_output_adder(x)
        x1=self.make_entities(x)
        x1 = self.dropout(x1)
        x2 = F.relu(x1)
        return x2

    def make_entities(self,x):
        """(batch, channels, h, w) -> (batch, h * w, channels): one entity per location."""
        x=x.reshape(x.shape[0],x.shape[1],-1)
        x=x.permute(0,2,1)
        return x


class AddCoordsTh(nn.Module):
    """Append normalised coordinate channels to a feature map (CoordConv).

    For an input of shape (batch, c, x_dim, y_dim), the output has c + 2
    channels (c + 3 with ``with_r``):
      * channel c holds the index along dim 2, scaled to [0, 1];
      * channel c + 1 holds the index along dim 3, scaled to [0, 1];
      * the optional last channel holds the distance from (0.5, 0.5) in those
        coordinates.
    """
    def __init__(self, x_dim=64, y_dim=64, with_r=False):
        super(AddCoordsTh, self).__init__()
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.with_r = with_r

    def forward(self, input_tensor):
        """
        input_tensor: (batch, c, x_dim, y_dim), with x_dim/y_dim as configured.
        """
        batch_size = input_tensor.shape[0]
        # Build the grids on the input's device and dtype. The concatenation
        # then works for CUDA / non-float32 inputs without relying on
        # torch.set_default_tensor_type.
        factory = dict(dtype=input_tensor.dtype, device=input_tensor.device)
        # max(..., 1) guards a size-1 axis, where dividing by zero would give NaN (0/0).
        xx = torch.arange(self.x_dim, **factory).view(1, 1, self.x_dim, 1) / max(self.x_dim - 1, 1)
        yy = torch.arange(self.y_dim, **factory).view(1, 1, 1, self.y_dim) / max(self.y_dim - 1, 1)
        xx_channel = xx.expand(batch_size, 1, self.x_dim, self.y_dim)
        yy_channel = yy.expand(batch_size, 1, self.x_dim, self.y_dim)

        ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)

        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)

        return ret




# Params

Uses argparse to ease porting to a command-line version. See the hack at the end of main that makes this work with Jupyter.


In [9]:
# Ways an AggregatingMHAEncoder can reduce its per-entity encoder output,
# chosen with --aggregation-method.
AGG_METHOD_NONE=0  # flatten all entities into a single vector
AGG_METHOD_MAX=1   # element-wise max over entities
AGG_METHOD_MHA=2   # AggregatingMultiHeadedAttention (used for any value other than 0/1)

class Arguments():
    """Command-line options for the model, dataset, and training loop.

    ``Arguments().args`` parses ``sys.argv``. Pass ``argv`` (e.g. ``[]`` in a
    notebook) to parse an explicit list instead.
    """

    def __init__(self, argv=None):
        self.parser = argparse.ArgumentParser(description='train')
        self.add_arguments()
        # argv=None makes argparse fall back to sys.argv[1:], as before.
        self.args = self.parser.parse_args(argv)

    def add_arguments(self):
        """Register every option group on ``self.parser``."""
        self.add_conv_arguments()
        self.add_dataset_arguments()
        self.add_training_arguments()
        self.add_encoder_arguments()

    def add_encoder_arguments(self):
        """Options for the per-task encoders and the final classifier."""
        # Normally encoder-count equals number of tasks in data
        # But could be extra dormant ones to maintain model architecture across model save/restore
        # with different datasets
        self.parser.add_argument('--encoder-count', type=int,metavar='N',
                                 default=3,help='number of parallel encoders')
        self.parser.add_argument('--tasks-in-data', nargs='+', default=['0','1','2'])
        # Hyphenated like the other options; the original underscore spelling
        # stays as an alias, and dest remains 'entity_count'. With the default
        # 36px input, kernel 12 and stride 6, the conv grid is 5 x 5 = 25 entities.
        self.parser.add_argument('--entity-count', '--entity_count', default=25, type=int, metavar='N',
            help='number of entities')
        self.parser.add_argument('--aggregation-method',type=int, metavar='N',
                            default=AGG_METHOD_NONE,help='aggregation method for encoder output')
        self.parser.add_argument('--dropout', type=float, metavar='D',
                            default=0.1,help='dropout probability')
        self.parser.add_argument('--d-model', type=int, metavar='N',
                            default=32,help='encoder d_model')
        self.parser.add_argument('--encoder-layers', type=int, metavar='N',
                            default=3,help='num encoder layers')
        self.parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            default=16,help='num encoder attention heads')
        self.parser.add_argument('--attention-dropout', type=float, metavar='D',
                            default=0.1,help='dropout probability for attention weights')
        self.parser.add_argument('--encoder-ffn-dim', type=int, metavar='N',
                            default=256,help='encoder dimension for FFN')
        # Hyphenated spelling plus the original underscore alias (dest 'tgt_class_count').
        self.parser.add_argument('--tgt-class-count', '--tgt_class_count', type=int, metavar='N',
                            default=2,help='tgt class count')
        self.parser.add_argument('--final-module-hidden-size', type=int, metavar='N',
                            default=1024,help='generator_hidden_size')
        self.parser.add_argument('--aggregate-attention-heads', type=int, metavar='N',
                            default=8,help='number of heads for AggregatingMHAAttention step')

    def add_training_arguments(self):
        """Options controlling training, checkpointing, and logging."""
        self.parser.add_argument('--do-restore',default=False,action='store_true',help="do restore")
        # store_false: passing the flag turns training off.
        self.parser.add_argument('--do-train',default=True,action='store_false',help="skip training (restore the saved model and evaluate)")
        self.parser.add_argument('--freeze-lower-layers',default=False,action='store_true',help="freeze lower layers")
        self.parser.add_argument('--check-interval', type=int, metavar='N',
                                 default=100,help='check interval')
        self.parser.add_argument('--base-path', metavar='N',
                                 default="/Users/azulay/jupyter_work/",help='base path')
        self.parser.add_argument('--model-path', metavar='N',
                                 default="model_data/saved_parms",help='model path')
        self.parser.add_argument('--tensorboard-path', metavar='N',
                                 default="rt_logs",help='tensorboard path')
        self.parser.add_argument('--batch-size', type=int, metavar='N',
                                 default=100,help='batch size')
        self.parser.add_argument('--test-batch-size', type=int, metavar='N',
                                 default=200,help='batch size')
        self.parser.add_argument('--max-epochs', type=int, metavar='N',
                                 default=20,help='max epochs')
        self.parser.add_argument('--default-device', metavar='N',
                                 default="cpu",help='default device')
        # store_false: passing the flag sets ignore_saved_counts to False.
        self.parser.add_argument('--ignore-saved-counts',default=True,action='store_false',help="don't ignore saved counts")

    def add_dataset_arguments(self):
        """Options locating the train/test data and controlling reports."""
        self.parser.add_argument('--data-dir', metavar='N',
                                 default="datasets/3task_col_patts_pentos/",help='data dir')
        self.parser.add_argument('--test-dir', metavar='N',
                                 default="datasets/3task_col_patts_stripes/",help='valid data dir')
        self.parser.add_argument('--report-size', type=int, metavar='N',
                                 default=2,help='report size')
        self.parser.add_argument('--report-label-offset', type=int, metavar='N',
                                 default=0,help='report label offset')
        self.parser.add_argument('--report-task-offset', type=int, metavar='N',
                                 default=1,help='report task offset')

    def add_conv_arguments(self):
        """Options for the ImageToEntities convolution front end."""
        self.parser.add_argument('--y-dim', type=int, metavar='N',
                                 default=36,help='size of y_dimention')
        self.parser.add_argument('--x-dim', type=int, metavar='N',
                                 default=36,help='size of x_dimention')
        self.parser.add_argument('--in-channels', type=int, metavar='N',
                                 default=3,help='# input channels')
        self.parser.add_argument('--out-channels', type=int, metavar='N',
                                 default=32,help='# output channels')
        self.parser.add_argument('--kernel-size', type=int, metavar='N',
                                 default=12,help='kernel size')
        self.parser.add_argument('--stride', type=int, metavar='N',
                                 default=6,help='stride')
        # store_false: coordinates are added by default; the flag disables them.
        self.parser.add_argument('--add-output-coords', action='store_false',
                                 default=True,help='do not add X and Y coordinate channels')
        # NOTE(review): store_false with default=False means this flag can never
        # set with_r to True. Switching it to store_true also requires
        # ImageToEntities to reserve room for the extra r channel.
        self.parser.add_argument('--with-r', action='store_false',
                                 default=False,help='with r')
        self.parser.add_argument('--conv-dropout', type=float, metavar='D',
                            default=0.0,help='dropout probability')
        

In [10]:
class Train:
    """Training, evaluation, and checkpoint helper around a model ("stack").

    ``stack`` is called as ``stack(input, task, y)`` and must return
    ``(logits, y)``. ``stack_optimizer`` must expose ``.optimizer`` (the
    wrapped torch optimizer) and ``.step()``.
    """

    def __init__(self,stack,stack_optimizer,criterion,args):
        self.stack=stack
        self.stack_optimizer=stack_optimizer
        self.criterion=criterion
        self.args=args

    @staticmethod
    def _accuracy(stack_output,y_tensor):
        """Percentage of rows whose arg-max logit equals the label, as a float tensor."""
        _, predicted = torch.max(stack_output, 1)
        correct = (predicted == y_tensor).sum()
        # Explicit float: integer tensor / int floors on older PyTorch versions.
        return 100.0 * correct.float() / y_tensor.size(0)

    def train(self,input_tensor, y_tensor,task_tensor):
        """Run one optimisation step on a batch.

        Returns:
            ``(loss, accuracy)``: the loss tensor, and the percentage of
            correctly classified samples as a float tensor.
        """
        # Use this trainer's own args rather than a notebook-global ``args``.
        if self.args.freeze_lower_layers:
            # Put the whole stack in eval mode, then switch the task encoders
            # and the final module back to train mode.
            # NOTE(review): eval() only changes mode-dependent layers such as
            # dropout; it does not stop gradients, so the lower layers are
            # still updated by the optimizer.
            self.stack.eval()
            self.stack.recombined_ff.train()
            for task_encoder in self.stack.task_encoders:
                task_encoder.train()
        else:
            self.stack.train()

        self.stack_optimizer.optimizer.zero_grad()

        # BranchingImageProcessor re-orders samples by task, so use the y it returns.
        stack_output,y_tensor=self.stack(input_tensor,task_tensor,y_tensor)

        loss = self.criterion(stack_output, y_tensor)
        loss.backward()

        self.stack_optimizer.step()
        return loss,self._accuracy(stack_output,y_tensor)

    def eval(self,input_tensor,y_tensor,task_tensor):
        """Evaluate one batch without updating weights or tracking gradients.

        Returns:
            ``(loss, accuracy)``, as in ``train``.
        """
        self.stack.eval()
        with torch.no_grad():
            stack_output,y_tensor=self.stack(input_tensor,task_tensor,y_tensor)
            loss = self.criterion(stack_output, y_tensor)
        return loss,self._accuracy(stack_output,y_tensor)

    def restore(self,model_dir):
        """Load model and optimizer state written by ``save``.

        Returns:
            ``(lowest_mean_loss, lowest_test_loss, epoch, batch_count)``.

        NOTE(review): torch.load unpickles the file, so only restore trusted
        checkpoints. A NoamOpt wrapper's own step counter is not part of the
        checkpoint, so its learning-rate warm-up restarts after a restore.
        """
        # Map tensors onto the configured device, so that a checkpoint saved
        # on GPU can be restored on a CPU-only machine (and vice versa).
        device=getattr(self.args,'default_device',None)
        checkpoint = torch.load(model_dir, map_location=device)
        self.stack.load_state_dict(checkpoint['stack_state_dict'])
        self.stack_optimizer.optimizer.load_state_dict(checkpoint['stack_optimizer_state_dict'])
        lowest_mean_loss=checkpoint['loss']
        lowest_test_loss=checkpoint['test_loss']
        epoch=checkpoint['epoch']
        batch_count=checkpoint['batch_count']
        print("restoring")
        print("saved epoch was ",epoch,"batch count was",batch_count)
        return lowest_mean_loss,lowest_test_loss,epoch,batch_count

    def save(self,mean_loss,test_loss,epoch,batch_count,model_dir):
        """Write model/optimizer state plus bookkeeping counters to ``model_dir``."""
        torch.save({
            'stack_state_dict': self.stack.state_dict(),
            'stack_optimizer_state_dict': self.stack_optimizer.optimizer.state_dict(),
            'loss': mean_loss,
            'test_loss' : test_loss,
            'epoch' : epoch,
            'batch_count' : batch_count,
             }, model_dir)
        

In [11]:
def _batch_to_device(sample_batched, device):
    """Return ``(image, label, task)`` from a dataloader batch, each moved to ``device``."""
    return (sample_batched['image'].to(device),
            sample_batched['label'].to(device),
            sample_batched['task'].to(device))


def main(args):
    """Build the branching model, then either evaluate a checkpoint or train.

    Every path in ``args`` (tensorboard_path, data_dir, test_dir, model_path)
    is joined onto ``args.base_path``.

    * ``args.do_train`` false: restore the checkpoint at ``model_path`` and run
      ``Train.eval`` over every test batch, printing loss and accuracy.
    * ``args.do_train`` true: optionally restore (``args.do_restore``, with
      ``args.ignore_saved_counts`` resetting the restored counters), optionally
      freeze every parameter except ``stack.generator``
      (``args.freeze_lower_layers``), then train for up to ``args.max_epochs``
      epochs. Every ``args.check_interval`` batches the running mean loss and
      accuracy are logged to TensorBoard, one shuffled test batch is evaluated,
      and a checkpoint is saved whenever the test loss reaches a new low.
    """
    tensorboard_path=os.path.join(args.base_path,args.tensorboard_path)
    data_dir=os.path.join(args.base_path,args.data_dir)
    test_dir=os.path.join(args.base_path,args.test_dir)
    model_path=os.path.join(args.base_path,args.model_path)
    external_code_path=os.path.join(args.base_path,"external_code")

    # Dataset has to live outside this notebook in order to use set_start_method('spawn') in main
    # Otherwise notebook hangs when using GPU.  Not sure why.  Will investigate at some point.
    if external_code_path not in sys.path:
        sys.path.append(external_code_path)

    from multi_task_dataset import MultiTaskDataset

    device=torch.device(args.default_device)

    writer = SummaryWriter(comment="thing grid",log_dir=tensorboard_path)
    # Close the writer on every exit path (including KeyboardInterrupt) so
    # pending TensorBoard events are flushed to disk.
    try:
        lowest_mean_loss=999999
        lowest_test_loss=999999
        total_loss=0
        total_accuracy=0
        epoch=0
        batch_count=0
        check_count=0

        total_data_files=len(os.listdir(data_dir))
        # max(1, ...) guards against a data directory holding fewer files than
        # one batch, which would make the end-of-epoch modulo check divide by 0.
        batches_per_epoch=max(1,total_data_files//args.batch_size)
        print("batches per epoch:",batches_per_epoch)

        if args.default_device=="cuda":
            torch.set_default_tensor_type(torch.cuda.FloatTensor)

        stack=BranchingImageProcessor.make_model(args)

        if args.default_device=="cuda":
            stack=stack.cuda()

        stack_optimizer=NoamOpt.get_std_opt(stack)
        criterion = nn.CrossEntropyLoss()

        train=Train(stack,stack_optimizer,criterion,args)

        if not args.do_train:
            # Evaluation only. restore() returns four values; unpack all of them.
            lowest_mean_loss,lowest_test_loss,epoch,batch_count=train.restore(model_path)
            # Same (args, directory) constructor call as the training path.
            test_dataset = MultiTaskDataset(args,test_dir)

            test_dataloader = data.DataLoader(test_dataset, batch_size=args.test_batch_size,
                                    shuffle=True,num_workers=1)

            for sample_v_batched in test_dataloader:
                v_image,v_label,v_task=_batch_to_device(sample_v_batched,device)
                test_loss,test_accuracy=train.eval(v_image,v_label,v_task)
                print("Test loss",test_loss.item(),"test accuracy",test_accuracy.item())
            return

        if args.do_restore:
            lowest_mean_loss,lowest_test_loss,epoch,batch_count=train.restore(model_path)
            if args.ignore_saved_counts:
                print("zeroing saved counts")
                lowest_mean_loss=999999
                lowest_test_loss=999999
                epoch=0
                batch_count=0

        if args.freeze_lower_layers:
            print("!!***** FREEZING PARMS *********!!")
            # Freeze everything, then re-enable gradients for the generator only.
            for param in stack.parameters():
                param.requires_grad = False
            for param in stack.generator.parameters():
                param.requires_grad = True

        for epoch in range(epoch,args.max_epochs):

            train_dataset = MultiTaskDataset(args,data_dir)
            test_dataset = MultiTaskDataset(args,test_dir)

            dataloader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                    shuffle=True,num_workers=1)

            test_dataloader = data.DataLoader(test_dataset, batch_size=args.test_batch_size,
                                    shuffle=True,num_workers=1)

            for sample_batched in dataloader:

                batch_count+=1
                image,label,task=_batch_to_device(sample_batched,device)

                loss,accuracy=train.train(image,label,task)
                check_count+=1
                total_loss+=loss.item()
                total_accuracy+=accuracy.item()

                if batch_count % args.check_interval==0:
                    mean_loss=total_loss/check_count
                    # True division; floor division truncated the mean accuracy.
                    mean_accuracy=total_accuracy/check_count
                    writer.add_scalar("loss", mean_loss, batch_count)
                    # Log the interval mean, like "loss", not just the last batch.
                    writer.add_scalar("accuracy",mean_accuracy,batch_count)
                    print("epoch",epoch,"count ",batch_count,"loss",mean_loss, "lowest loss",lowest_mean_loss,"accuracy",mean_accuracy)
                    # Spot-check on one shuffled test batch. next(iter(...)) is
                    # the Python 3 iterator protocol; the iterator's .next()
                    # method is not available on every torch version.
                    v_image,v_label,v_task=_batch_to_device(next(iter(test_dataloader)),device)
                    test_loss,test_accuracy=train.eval(v_image,v_label,v_task)
                    test_loss=test_loss.item()
                    test_accuracy=test_accuracy.item()
                    writer.add_scalar("test loss", test_loss, batch_count)
                    writer.add_scalar("test accuracy",test_accuracy,batch_count)
                    print("epoch",epoch,"test loss",test_loss, "lowest test loss",lowest_test_loss,"test accuracy",test_accuracy)
                    # Checkpoint whenever the test loss reaches a new low.
                    if test_loss<lowest_test_loss:
                        lowest_test_loss=test_loss
                        train.save(lowest_mean_loss,lowest_test_loss,epoch,batch_count,model_path)
                    if mean_loss<lowest_mean_loss:
                        lowest_mean_loss=mean_loss
                    total_loss=0
                    total_accuracy=0
                    check_count=0
                # End the epoch when batch_count hits a multiple of
                # batches_per_epoch. NOTE: a restored batch_count need not be
                # aligned to that multiple, so the first resumed epoch may be short.
                if batch_count % batches_per_epoch==0:
                    break
    finally:
        writer.close()



if __name__ == "__main__":

    # Use the 'spawn' start method for worker processes; per the note in
    # main(), the notebook hangs on GPU without it.
    try:
        set_start_method('spawn',force=True)
    except RuntimeError:
        print("bad")

    print(sys.argv[0])
    # Keep only the program name. Inside Jupyter the remaining entries are
    # the kernel launcher's options; presumably Arguments() parses sys.argv,
    # so this makes it fall back to its defaults -- confirm against Arguments.
    sys.argv = sys.argv[:1]
    arguments = Arguments()
    args = arguments.args
    print(args)
    main(args)

    

/Users/azulay/opt/anaconda3/envs/new/lib/python3.6/site-packages/ipykernel_launcher.py
Namespace(add_output_coords=True, aggregate_attention_heads=8, aggregation_method=0, attention_dropout=0.1, base_path='/Users/azulay/jupyter_work/', batch_size=100, check_interval=100, conv_dropout=0.0, d_model=32, data_dir='datasets/3task_col_patts_pentos/', default_device='cpu', do_restore=False, do_train=True, dropout=0.1, encoder_attention_heads=16, encoder_count=3, encoder_ffn_dim=256, encoder_layers=3, entity_count=25, final_module_hidden_size=1024, freeze_lower_layers=False, ignore_saved_counts=True, in_channels=3, kernel_size=12, max_epochs=20, model_path='model_data/saved_parms', out_channels=32, report_label_offset=0, report_size=2, report_task_offset=1, stride=6, tasks_in_data=['0', '1', '2'], tensorboard_path='rt_logs', test_batch_size=200, test_dir='datasets/3task_col_patts_stripes/', tgt_class_count=2, with_r=False, x_dim=36, y_dim=36)
got it /Users/azulay/jupyter_work/external_code
bat

KeyboardInterrupt: 