This notebook is a supplement to 3-6 inference, and it mostly just ensures torch is working as expected.

In [1]:
from abc import ABC, abstractmethod

import os
from tqdm import tqdm
import math 
import time

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns 

from scipy import stats

import torch
from torch.distributions import Beta
from torch.distributions.bernoulli import Bernoulli
from torch.nn.functional import log_softmax
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# One shared seed for both global RNGs, so the random draws in later cells
# repeat on a fresh kernel.
SEED = 1
np.random.seed(seed=SEED)
torch.manual_seed(seed=SEED)

<torch._C.Generator at 0x7ff6a56b3550>

In [3]:
"""
Test 1- check if .log_prob() broadcasts across a tensor

A Bernoulli with a single success probability is evaluated on a [2, 3]
tensor of 0/1 outcomes. The result should have the same [2, 3] shape, with
p wherever the outcome is 1 and 1 - p wherever it is 0.
"""
p_success = 0.7
test = torch.tensor(
    [
    [1., 0., 1.],
    [1., 0., 1.]
    ]
)
test_dist = Bernoulli(torch.tensor([p_success]))
probs = torch.exp(test_dist.log_prob(test))

# Check the broadcast explicitly instead of relying on reading the output:
# the Bernoulli pmf is p^x * (1 - p)^(1 - x), which for x in {0, 1} is
# x * p + (1 - x) * (1 - p).
expected = test * p_success + (1 - test) * (1 - p_success)
assert probs.shape == test.shape
assert torch.allclose(probs, expected)

probs

tensor([[0.7000, 0.3000, 0.7000],
        [0.7000, 0.3000, 0.7000]])

In [4]:
"""
Test 2- given W (size [3]), u (size [B, G]), get Wu

The target is a [B, G, 3] tensor in which every scalar entry of u
scales its own copy of W.
"""
B, G = 2, 5
W = torch.tensor([1., 2., 3.])
u_mat = torch.ones(B, G)

print(f'u_mat :\n {u_mat}')

# Lift W from [3] to [1, 1, 3]: a single 1x3 row matrix with a batch dim of 1.
W_exp = W[None, None, :]
print(f'W_exp:\n {W_exp}')
print(f'W_exp size: {W_exp.size()}')

# Lift u from [B, G] to [B, G, 1]: each batch item becomes a Gx1 column matrix.
u_exp = u_mat[:, :, None]
print(f'u_exp[:1]:\n {u_exp[:1]}')
print(f'u_exp size:\n {u_exp.size()}')

# Batched (G x 1) @ (1 x 3) product; W_exp's batch dim of 1 broadcasts over B.
Wu = u_exp @ W_exp

assert tuple(Wu.shape) == (B, G, 3)
print(Wu)

u_mat :
 tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
W_exp:
 tensor([[[1., 2., 3.]]])
W_exp size: torch.Size([1, 1, 3])
u_exp[:1]:
 tensor([[[1.],
         [1.],
         [1.],
         [1.],
         [1.]]])
u_exp size:
 torch.Size([2, 5, 1])
tensor([[[1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.]],

        [[1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.]]])


In [5]:
"""
Test 3- We now need to compute dot(x, log Softmax(Wu))

The softmax is taken along dim 2 (the 3 entries that come from W), so the
result keeps the [B, G, 3] shape of Wu.
"""

log_sm = log_softmax(input=Wu, dim=2)
log_sm


tensor([[[-2.4076, -1.4076, -0.4076],
         [-2.4076, -1.4076, -0.4076],
         [-2.4076, -1.4076, -0.4076],
         [-2.4076, -1.4076, -0.4076],
         [-2.4076, -1.4076, -0.4076]],

        [[-2.4076, -1.4076, -0.4076],
         [-2.4076, -1.4076, -0.4076],
         [-2.4076, -1.4076, -0.4076],
         [-2.4076, -1.4076, -0.4076],
         [-2.4076, -1.4076, -0.4076]]])

In [10]:
"""
Test 3- dot(x, log Softmax(Wu))

Now that we have log softmax, we need to dot x with log_sm
Looking above, notice we have 2 x values (batch size is 2)
The first 5x3 matrix should be dotted with x1 = {x11, x12, x13},
whereas the second should be dotted with x2 = {x21, x22, x23}

and then we need to sum across dim 2 (the dim with 3 elements)
"""

'\nTest 3- dot(x, log Softmax(Wu))\n\nNow that we have log softmax, we need to dot x with log_sm\nLooking above, notice we have 2 x values (batch size is 2)\nThe first 5x3 matrix should be dotted with x1 = {x11, x12, x13},\nwhereas the second should be dotted with x2 = {x21, x22, x23}\n\nand then we need to sum across dim 2 (the dim with 3 elements)\n'

In [6]:
# One row of counts per batch item.
x_batch = torch.tensor([
    [7., 3., 5.],
    [12., 2., 1.],
])

# [B, 3] -> [B, 1, 3], so each row of counts broadcasts across the G grid points.
x_exp = x_batch[:, None, :]
print(x_exp)

# Elementwise product first, then collapse the 3-entry axis to finish the dot product.
weighted_log_sm = x_exp * log_softmax(Wu, dim=2)
print(weighted_log_sm)
res = weighted_log_sm.sum(dim=2)
print(res)

assert tuple(res.shape) == (B, G)

tensor([[[ 7.,  3.,  5.]],

        [[12.,  2.,  1.]]])
tensor([[[-16.8532,  -4.2228,  -2.0380],
         [-16.8532,  -4.2228,  -2.0380],
         [-16.8532,  -4.2228,  -2.0380],
         [-16.8532,  -4.2228,  -2.0380],
         [-16.8532,  -4.2228,  -2.0380]],

        [[-28.8913,  -2.8152,  -0.4076],
         [-28.8913,  -2.8152,  -0.4076],
         [-28.8913,  -2.8152,  -0.4076],
         [-28.8913,  -2.8152,  -0.4076],
         [-28.8913,  -2.8152,  -0.4076]]])
tensor([[-23.1141, -23.1141, -23.1141, -23.1141, -23.1141],
        [-32.1141, -32.1141, -32.1141, -32.1141, -32.1141]])


In [7]:
"""
Test 4- adding constants to log likelihood

We need to add log(S!), where S is the total count in each row of x_batch.
We also need to subtract log(x1!) + log(x2!) + log(x3!).
"""
def log_factorial(x):
    """Return log(x!) elementwise as lgamma(x + 1).

    Working in log space avoids forming x! itself, which overflows float32
    to inf once x! exceeds about 3.4e38 (x >= 35).
    """
    return torch.lgamma(x + 1)


def factorial(x):
    """Return x! elementwise. Kept under its original name; overflows for large x."""
    return torch.exp(log_factorial(x))


### add log(S!)
# S comes from the data (row totals of x_batch) rather than a hard-coded 15.
# keepdim=True gives a [B, 1] column that broadcasts across the grid axis of res.
log_S_factorial = log_factorial(x_batch.sum(dim=1, keepdim=True))

print(f'log(factorial(x_batch)):\n {log_factorial(x_batch)}')
other = log_factorial(x_batch).sum(dim=1, keepdim=True)
print(other)

# Build a new tensor instead of updating `res` in place, so re-running this
# cell does not add the constants a second time.
res_with_consts = res + log_S_factorial - other
res_with_consts

log(factorial(x_batch)):
 tensor([[ 8.5252,  1.7918,  4.7875],
        [19.9872,  0.6931,  0.0000]])
tensor([[15.1044],
        [20.6804]])


tensor([[-10.3192, -10.3192, -10.3192, -10.3192, -10.3192],
        [-24.8952, -24.8952, -24.8952, -24.8952, -24.8952]])

In [8]:
"""
Test 5- discretizing space depending on y

Given y_batch with size [batch_size], build a [batch_size, grid_size]
matrix whose rows are grid_size linearly spaced points from
    (0, 1) if y_batch[row] == 1
    (-1, 0) if y_batch[row] == 0
"""
B = 10
G = 7

y_batch = torch.tensor([1., 0., 1., 1., 0., 1., 0., 0., 1., 0.])

# The grids start and stop one step inside the interval, so the endpoints
# are never included.
step = 1/(G+1)
positive_grid = torch.linspace(step, 1-step, G)
negative_grid = torch.linspace(-1+step, -step, G)

is_positive = y_batch == 1
is_negative = y_batch == 0

# Fill row by row according to the label; every row is written because
# y_batch only holds 0s and 1s.
u_mat = torch.empty(B, G)
u_mat[is_positive] = positive_grid.repeat(int(is_positive.sum()), 1)
u_mat[is_negative] = negative_grid.repeat(int(is_negative.sum()), 1)

# Show each grid row next to its label (last column) for a visual check.
torch.hstack([u_mat, y_batch.unsqueeze(1)])

tensor([[ 0.1250,  0.2500,  0.3750,  0.5000,  0.6250,  0.7500,  0.8750,  1.0000],
        [-0.8750, -0.7500, -0.6250, -0.5000, -0.3750, -0.2500, -0.1250,  0.0000],
        [ 0.1250,  0.2500,  0.3750,  0.5000,  0.6250,  0.7500,  0.8750,  1.0000],
        [ 0.1250,  0.2500,  0.3750,  0.5000,  0.6250,  0.7500,  0.8750,  1.0000],
        [-0.8750, -0.7500, -0.6250, -0.5000, -0.3750, -0.2500, -0.1250,  0.0000],
        [ 0.1250,  0.2500,  0.3750,  0.5000,  0.6250,  0.7500,  0.8750,  1.0000],
        [-0.8750, -0.7500, -0.6250, -0.5000, -0.3750, -0.2500, -0.1250,  0.0000],
        [-0.8750, -0.7500, -0.6250, -0.5000, -0.3750, -0.2500, -0.1250,  0.0000],
        [ 0.1250,  0.2500,  0.3750,  0.5000,  0.6250,  0.7500,  0.8750,  1.0000],
        [-0.8750, -0.7500, -0.6250, -0.5000, -0.3750, -0.2500, -0.1250,  0.0000]])

In [23]:
"""
Test 6- getting posterior from log_joint (size [B, G])

Exponentiating log_joint gives the joint; dividing each row by its own sum
then gives a [B, G] matrix whose n-th row is the posterior for the
x^{(n)}, y^{(n)} pair.
"""

log_joint = torch.normal(0, 5, size=(2, 3))

# Direct (unshifted) normalisation: exponentiate once, then divide by row sums.
joint = torch.exp(log_joint)
normalized_posterior = joint / joint.sum(dim=1, keepdim=True)

print(f'Log Joint:\n {log_joint}')
print(f'Normalized Posterior: \n{normalized_posterior}')


Log Joint:
 tensor([[-7.6391,  1.9458, -3.7146],
        [-0.2663,  2.2826,  3.0587]])
Normalized Posterior: 
tensor([[6.8515e-05, 9.9646e-01, 3.4687e-03],
        [2.4044e-02, 3.0758e-01, 6.6838e-01]])


In [24]:
### same posterior via the log-sum-exp trick: subtract each row's max before exponentiating
max_log_prob = log_joint.max(dim=1, keepdim=True).values
print(max_log_prob)
joint_probs = (log_joint - max_log_prob).exp()
row_totals = joint_probs.sum(dim=1, keepdim=True)
posterior = joint_probs / row_totals
posterior

tensor([[1.9458],
        [3.0587]])


tensor([[6.8515e-05, 9.9646e-01, 3.4687e-03],
        [2.4044e-02, 3.0758e-01, 6.6837e-01]])

In [26]:
### detach (code from ChatGPT)

# Compares three gradients of sum_n f_n * g with respect to (theta1, theta2), where
#   f_n = sin(theta1^2 + theta2) / n   and   g = cos(theta1 + theta2^3):
#   "With Grad"      - built by hand as sum_n f_n * grad(g)
#   "With Detach"    - autograd on sum_n detach(f_n) * g
#   "Without Detach" - autograd on sum_n f_n * g (no detach)
# One result per N is collected in each of the lists below.

results_with_grad = []
results_with_detach = []
results_without_detach = []

for N in range(1, 101, 5):
    # Fresh leaf tensors for each N, so no .grad carries over from the previous N
    theta1 = torch.tensor(3.2, requires_grad=True)
    theta2 = torch.tensor(1.2, requires_grad=True)
    
    gradient_contributions = []
    detached_contributions = []
    without_detach_contributions = []

    for n in range(1, N + 1):
        sin_term = torch.sin(theta1.pow(2) + theta2) / n
        cos_term = torch.cos(theta1 + theta2.pow(3))
        
        # Manual version: take grad(cos_term) and scale it by sin_term.
        # retain_graph=True keeps cos_term's graph alive, because the two
        # summed backward passes after this loop go through it again.
        cos_term.backward(retain_graph=True)
        grad_with_grad = [theta1.grad.clone(), theta2.grad.clone()]
        gradient_contributions.append(sin_term * torch.stack(grad_with_grad))
        # backward() accumulates into .grad, so reset before the next backward call
        theta1.grad.zero_()
        theta2.grad.zero_()
        
        # Detached sin term: sin_term enters the product as a constant
        sin_term_detached = sin_term.detach()
        detached_contributions.append(sin_term_detached * cos_term)

        # Without detaching: gradients flow through both factors
        without_detach_contributions.append(sin_term * cos_term)

    # Sum the manual contributions per parameter (index 0 -> theta1, index 1 -> theta2)
    sum_with_grad = torch.stack([sum([gc[i] for gc in gradient_contributions]) for i in range(2)])

    # Sum the scalar terms for the detached and non-detached versions
    sum_with_detach = sum(detached_contributions)
    sum_without_detach = sum(without_detach_contributions)

    # Gradient of the detached sum; retain_graph=True because the backward
    # pass below goes through the same cos_term graphs
    sum_with_detach.backward(retain_graph=True)
    detached_grads = [theta1.grad.item(), theta2.grad.item()]
    
    # Clear the accumulated grads so the next backward is read in isolation
    theta1.grad.zero_()
    theta2.grad.zero_()
    
    sum_without_detach.backward()
    without_detach_grads = [theta1.grad.item(), theta2.grad.item()]

    # Store results (the manual gradient is detached before converting to NumPy)
    results_with_grad.append((N, sum_with_grad.detach().numpy()))
    results_with_detach.append((N, detached_grads))
    results_without_detach.append((N, without_detach_grads))

# Print the three gradients side by side for each N
for i, N in enumerate(range(1, 101, 5)):
    print(f"N={N}: With Grad={results_with_grad[i][1]}, With Detach={results_with_detach[i][1]}, Without Detach={results_without_detach[i][1]}")


N=1: With Grad=[-0.88195246 -3.8100348 ], With Detach=[-0.8819524645805359, -3.81003475189209], Without Detach=[-0.2932586669921875, -3.7180514335632324]
N=6: With Grad=[-2.1607833 -9.334585 ], With Detach=[-2.160783529281616, -9.334585189819336], Without Detach=[-0.7184838652610779, -9.109225273132324]
N=11: With Grad=[ -2.6633883 -11.505837 ], With Detach=[-2.663388252258301, -11.505838394165039], Without Detach=[-0.8856051564216614, -11.228058815002441]
N=16: With Grad=[ -2.9816425 -12.880694 ], With Detach=[-2.981642246246338, -12.880696296691895], Without Detach=[-0.9914281964302063, -12.569723129272461]
N=21: With Grad=[ -3.2150335 -13.888944 ], With Detach=[-3.2150330543518066, -13.888943672180176], Without Detach=[-1.069032907485962, -13.553629875183105]
N=26: With Grad=[ -3.3994153 -14.6854725], With Detach=[-3.3994150161743164, -14.685473442077637], Without Detach=[-1.1303420066833496, -14.33092975616455]
N=31: With Grad=[ -3.551839 -15.343944], With Detach=[-3.55183863639831