This notebook is a supplement to 3-6 inference. It consists mostly of quick sanity checks that torch behaves as expected.

In [1]:
from abc import ABC, abstractmethod

import os
from tqdm import tqdm
import math 
import time

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns 

from scipy import stats

import torch
from torch.distributions import Beta
from torch.distributions.bernoulli import Bernoulli
from torch.nn.functional import log_softmax
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# Seed every RNG library the notebook uses so Restart & Run All reproduces results.
SEED = 1
np.random.seed(SEED)
# The trailing ';' hides the returned torch.Generator repr, which is just noise.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7fae7b577350>

In [3]:
"""
Checking if .log_prob() broadcasts across a tensor
"""
test = torch.tensor(
    [
    [1., 0., 1.],
    [1., 0., 1.]
    ]
)
test_dist = Bernoulli(torch.tensor([0.7]))
torch.exp(test_dist.log_prob(test))

tensor([[0.7000, 0.3000, 0.7000],
        [0.7000, 0.3000, 0.7000]])

In [4]:
"""
Given an NxM matrix and an Bx1 vector, create a tensor size
NxMxB which is the vector multiplied by each of the matrix values 
"""
test_u = torch.tensor(
    [
    [1., 2., 2.4],
    [1., 3., 1.1]
    ]
)
test_W = torch.tensor([1.,1.,2.])

test_res = torch.matmul(test_u.unsqueeze(2), test_W.unsqueeze(0).unsqueeze(0))

assert torch.all(test_res[0, 1] == torch.tensor([2., 2., 4.]))

In [6]:
"""
Numerically stable factorial
Uses gamma function
"""
factorial = lambda x : torch.exp(torch.lgamma(x+1))

assert factorial(torch.tensor(1.)) == torch.tensor(1.)
assert factorial(torch.tensor(2.)) == torch.tensor(2.)
assert factorial(torch.tensor(3.)) == torch.tensor(6.)
print(f"2.9999! = {factorial(torch.tensor(2.9999))}")

2.9999! = 5.999247074127197
