In [1]:
import torch
import torch.nn as nn
import torch.functional as F

import pyro
import pyro.distributions as dist
import pyro.infer
import pyro.optim

In [94]:
import os
smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 2000

In [15]:
import numpy as np
import math

# Concept

In [16]:
class Concept:
    
    def __init__(self, mult, params, max_num):
        """
        :param mult: boolean flag that is True if concept is multiple, False if concept is range
        :param params: list of params - [mult] if multiple concept, [start, end] if range concept
        """
        self.mult = mult
        self.params = params
        self.max_num = max_num
        self._numbers = self._populate_numbers()
        self._probs = np.zeros(max_num + 1)
        self._probs[self._numbers] = 1
    
    
    def _populate_numbers(self):
        """
        Fill in the number set for the concept
        """
        if self.mult:
            if (self.params[0] < 0) or (self.params[0] > self.max_num):
                raise Exception('Invalid multiple for concept.')
            
            if not isinstance(self.params[0], int):
                raise Exception('Multiple must be an integer.')
                
            numbers = np.arange(self.params[0], self.max_num + 1, self.params[0])
            return numbers
        else:
            start, end = self.params[0], self.params[1]
            
            if (start < 0) or (end > self.max_num) or (start > end):
                raise Exception('Invalid interval for concept.')
            
            if not (isinstance(start, int) and isinstance(end, int)):
                raise Exception('Start and end must be integers.')
                
            numbers = np.arange(start, end+1)
            return numbers
        
        
    def sample(self, num_samples):
        """
        Sample num_samples observations from the concept
        """
        samples = []

        for i in range(num_samples):
            sample = random.choice(self.numbers)
            samples.append(sample)

        return samples
    
    
    def __repr__(self):
        if self.mult:
            return 'Multiples of {}: {}'.format(self.params[0], self.numbers)
        else:
            return 'Range [{}, {}]: {}'.format(self.params[0], self.params[1], self.numbers)
    
    @property
    def numbers(self):
        return self._numbers

    @property
    def probs(self):
        return self._probs
    
    def __len__(self):
        return self._numbers.shape[0]

# Generative model

In [127]:
class NumberGame:
    
    def __init__(self, max_num, mult_range_start, mult_range_end, interval_range_length):
        self.max_num = max_num
        self.mult_range_start = mult_range_start
        self.mult_range_end = mult_range_end
        self.interval_range_length = interval_range_length
        self.concepts, self.concept_probs = self.generate_all_concepts(20)
        
    
    def _generate_multiple_concept(self, multiple):
        """
        Generates a set of all multiples of multiple from [1,... , 100]
        :param multiple: integer from [1, 100]
        """
        mult_concept = Concept(True, [multiple], self.max_num)
        return mult_concept
    
    
    def _generate_range_concept(self, start, end):
        """
        Generates a set of numbers from [start, ..., end] inclusive
        """
        range_concept = Concept(False, [start, end], self.max_num)
        return range_concept
    
    
    def generate_all_concepts(self, mult_weight):
        """
        Generate a set of all concepts as specified by parameters
        :param mult_range_start: integer with smallest mult considered
        :param mult_range_end: integer with largest mult considered
        :param interval_range_length: length of ranges being considered
        """
        concepts = []
        
        for multiple in range(self.mult_range_start, self.mult_range_end + 1):
            mult_concept = self._generate_multiple_concept(multiple)
            concepts.append(mult_concept)
            
        i = 1
        while i + self.interval_range_length <= self.max_num:
            range_concept = self._generate_range_concept(i, i + self.interval_range_length)
            concepts.append(range_concept)
            i += 1
        
        concept_probs = np.ones(len(concepts)) + np.concatenate((mult_weight * np.ones(self.mult_range_end - self.mult_range_start), np.zeros(len(concepts) - (self.mult_range_end - self.mult_range_start)))) 

        return (concepts, concept_probs)

In [128]:
def model(max_num, mult_range_start, mult_range_end, interval_range_length, observations={"obs1": 0, "obs2": 0, "obs3": 0}):
    number_game = NumberGame(max_num, mult_range_start, mult_range_end, interval_range_length)
    
    concepts = number_game.concepts
    concept_probs = number_game.concept_probs
    concept_index = int(pyro.sample("c", dist.Categorical(torch.from_numpy(concept_probs))).numpy())
    concept = concepts[concept_index]
    
    # define the likelihood
    print(concept.probs)
    likelihood = dist.Categorical(torch.from_numpy(concept.probs))
    print(likelihood)
    
    # samples
    y1 = pyro.sample("obs1", likelihood, obs=observations["obs1"])
    y2 = pyro.sample("obs2", likelihood, obs=observations["obs2"])
    y3 = pyro.sample("obs3", likelihood, obs=observations["obs3"])
    
    print(y1, y2, y3)
    
    return concept

In [129]:
model(100, 2, 10, 10)

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 0. 0.]
Categorical(probs: torch.Size([101]))
0 0 0


Range [88, 98]: [88 89 90 91 92 93 94 95 96 97 98]

# Guide

In [157]:
class Guide(nn.Module):
    def __init__(self):
        super(Guide, self).__init__()
        self.neural_net = nn.Sequential(
            nn.Linear(3, 10),
            nn.ReLU(),
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 10),
            nn.ReLU(),
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 1))

    def forward(self, max_num, mult_range_start, mult_range_end, interval_range_length, observations={"obs1": 0, "obs2": 0, "obs3": 0}):
        pyro.module("guide", self)
        obs1 = observations["obs1"]
        obs2 = observations["obs2"]
        obs3 = observations["obs3"]
        v = torch.cat((obs1.view(1, 1), obs2.view(1, 1), obs3.view(1, 1)), 1).float()
        v = self.neural_net(v)

        mean = v[0]
        print("Mean: ", mean)
        std = v[0].exp()
        print("Std: ", std)
        pyro.sample("z", dist.Normal(mean, std))

guide = Guide()

In [158]:
optimiser = pyro.optim.Adam({'lr': 1e-3})
csis = pyro.infer.CSIS(model, guide, optimiser, num_inference_samples=50)

max_num = 100
mult_range_start = 2
mult_range_end = 10
interval_range_length = 10

In [159]:
for step in range(n_steps):
    csis.step(max_num, mult_range_start, mult_range_end, interval_range_length)

[0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 0.]
Categorical(probs: torch.Size([101]))
tensor(72) tensor(40) tensor(80)
Mean:  tensor([-0.9981], grad_fn=<SelectBackward>)
Std:  tensor([0.3686], grad_fn=<ExpBackward>)


RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.