# In [1]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.functional as F  # NOTE(review): this is torch.functional, not torch.nn.functional — confirm intent

import pyro
import pyro.distributions as dist
import pyro.infer
import pyro.optim

# Concept

# In [2]:
class Concept:
    """A hypothesis ("concept") over the integers 0..100.

    A concept is either the multiples of an integer ``m`` (``m, 2m, ...`` up to
    and including 100 when reachable) or an inclusive integer range
    ``[start, end]`` lying within 0..100.
    """

    def __init__(self, mult, params):
        """
        :param mult: boolean flag that is True if concept is multiple, False if concept is range
        :param params: list of params - [mult] if multiple concept, [start, end] if range concept
        :raises TypeError: if the multiple, or start/end, is not an integer
        :raises ValueError: if the multiple is outside 1..100, or the interval is
            outside 0..100 or has start > end
        """
        self.mult = mult
        self.params = params
        self._numbers = self._populate_numbers()
        # 0/1 membership indicator over the 101 values 0..100 (not normalized).
        self._probs = np.zeros(101)
        self._probs[self._numbers] = 1

    def _populate_numbers(self):
        """
        Fill in the number set for the concept.

        :return: 1-D integer array of the values contained in the concept
        """
        if self.mult:
            step = self.params[0]
            # Check the type first so non-numeric input gets a clear error
            # rather than failing inside the range comparison below.
            if not isinstance(step, (int, np.integer)):
                raise TypeError('Multiple must be an integer.')
            # A multiple of 0 is rejected here; np.arange would otherwise
            # raise ZeroDivisionError for a zero step.
            if (step < 1) or (step > 100):
                raise ValueError('Invalid multiple for concept.')
            return np.arange(step, 101, step)

        start, end = self.params[0], self.params[1]

        if not (isinstance(start, (int, np.integer)) and isinstance(end, (int, np.integer))):
            raise TypeError('Start and end must be integers.')

        if (start < 0) or (end > 100) or (start > end):
            raise ValueError('Invalid interval for concept.')

        return np.arange(start, end + 1)

    def sample(self, num_samples):
        """
        Sample num_samples observations uniformly, with replacement, from the concept.

        Draws use the stdlib ``random`` module, so they follow ``random.seed``.

        :param num_samples: number of draws (0 or negative yields an empty list)
        :return: list of drawn values (NumPy integer scalars)
        """
        return [random.choice(self.numbers) for _ in range(num_samples)]

    def __repr__(self):
        if self.mult:
            return 'Multiples of {}: {}'.format(self.params[0], self.numbers)
        else:
            return 'Range [{}, {}]: {}'.format(self.params[0], self.params[1], self.numbers)

    @property
    def numbers(self):
        """Array of the integers belonging to this concept."""
        return self._numbers

    @property
    def probs(self):
        """Length-101 float array: 1 at indices in the concept, 0 elsewhere."""
        return self._probs

    def __len__(self):
        """Number of integers in the concept."""
        return self._numbers.shape[0]

# Generative model