In [105]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from copy import deepcopy

# --- optimisation settings ---
n_epochs = 5  # passes over the training set
learning_rate = 0.001  # step size for Adam
log_interval = 10  # print (and checkpoint) every this many batches
# NOTE(review): the DataLoader cell below hardcodes batch_size=64 instead of these two values.
batch_size_train = 1024  # batch size for training
batch_size_test = 8192  # batch size for testing

# --- hardware ---
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- reproducibility ---
random_seed = 1
torch.backends.cudnn.enabled = False  # cuDNN disabled, presumably to avoid non-deterministic kernels
torch.manual_seed(random_seed)

<torch._C.Generator at 0x7f7b59313bf0>

In [106]:
class Net(nn.Module):
    """Two-stage model: ``f1`` encodes the input, ``f2`` mixes that encoding with a label hint.

    ``forward(x, y_hat)`` returns the ``(class_probs, z)`` pair produced by ``f2``.
    """

    def __init__(self):
        super().__init__()
        # Sub-module order (f1 before f2) fixes parameter-initialisation order.
        self.f1 = F1()
        self.f2 = F2()

    def forward(self, x, y_hat, return_mid=False):
        # ``return_mid`` is accepted for call compatibility but is currently ignored.
        features = self.f1(x)
        probs, extra = self.f2(features, y_hat)
        return probs, extra
    
    
class F1(nn.Module):
    """Convolutional encoder from single-channel images to 50-d feature vectors.

    Two ``conv(5x5) -> max_pool(2) -> LeakyReLU`` stages, then a linear layer.
    ``fc1`` takes 320 = 20 channels * 4 * 4 inputs, which is what a 28x28 image
    yields after the two stages (28 -> 24 -> 12 -> 8 -> 4).
    """

    def __init__(self):
        super(F1, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.activation = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x):
        x = self.activation(F.max_pool2d(self.conv1(x), 2))
        x = self.activation(F.max_pool2d(self.conv2(x), 2))
        # Flatten per sample, keeping the batch dimension. The previous
        # ``view(-1, 320)`` silently regrouped samples into a different batch size
        # when the spatial size was wrong; now fc1 raises on a size mismatch instead.
        x = torch.flatten(x, 1)
        x = self.activation(self.fc1(x))
        return x
    
    
class F2(nn.Module):
    """Head that combines a 50-d feature vector with a 10-d label hint.

    ``forward(x, y_hat)`` returns ``(probs, z)``: a softmax over the first 10 outputs
    of ``fc3`` and the remaining 50 raw outputs.
    """

    def __init__(self):
        super().__init__()
        # Keep this registration order and these names: they fix the parameter-init
        # order and the state_dict keys used by saved checkpoints.
        self.fc2 = nn.Linear(50, 50)
        self.fc3 = nn.Linear(50, 60)
        self.y_hat_fc = nn.Linear(10, 50)
        self.activation = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x, y_hat):
        hint = self.activation(self.y_hat_fc(y_hat))
        hidden = self.activation(self.fc2(x + hint))
        out = self.fc3(hidden)
        logits, z = out[:, :10], out[:, 10:]
        return logits.softmax(-1), z

# Model Training

In [107]:
def js_div(p, q):
    """Jensen-Shannon divergence between two batches of probability distributions.

    JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), where m = (p + q) / 2.
    Uses ``reduction='batchmean'``: the element-wise terms are summed and divided
    by the size of the first dimension.

    Args:
        p, q: same-shape tensors of probabilities (e.g. softmax outputs).

    Returns:
        0-dim tensor; 0 when p == q and at most log(2) per sample.
    """
    m = 0.5 * (p + q)
    # F.kl_div(input, target) computes KL(target || exp(input)), so log(m) must be
    # the ``input`` and p / q the ``target``. The previous argument order computed
    # KL(m || p), which is not JS and is infinite wherever p or q has a zero.
    # The clamp only matters where m == 0, which forces p == q == 0 there; it keeps
    # 0 * log(0) from turning into NaN.
    log_m = m.clamp_min(torch.finfo(m.dtype).tiny).log()
    return 0.5 * (F.kl_div(log_m, p, reduction='batchmean') +
                  F.kl_div(log_m, q, reduction='batchmean'))

def train(f, train_loader, optimizer, n_epochs, n_classes=10):
  """Train ``f`` with the two-pass supervised objective.

  For each batch, a one-hot label hint ``y0`` is fed together with the input
  (``f(x, y0)``). The resulting prediction ``y1`` is then fed back through ``f.f2``
  as the hint for a second prediction ``y2``. The loss is the sum of the NLL of
  both predictions. Every ``log_interval`` batches (module-level setting) the two
  losses are printed, and the model / optimizer state is saved to
  ``./model.pth`` / ``./optimizer.pth``.

  Args:
    f: model called as ``f(x, y_hat) -> (probs, z)`` and exposing ``f.f2(z, y_hat)``.
    train_loader: DataLoader yielding ``(data, target)`` batches.
    optimizer: optimizer over ``f``'s parameters.
    n_epochs: number of passes over ``train_loader``.
    n_classes: number of classes (width of the one-hot hint).
  """
  f.train()
  for epoch in range(1, n_epochs + 1):
    for batch_idx, (data, target) in enumerate(train_loader):
      x, y = data.to(device), target.to(device)

      # Once per batch: with probability 0.75 the hint is the true label y;
      # otherwise every sample gets a uniformly random class.
      y0_d = y if torch.rand(1) > 0.25 else torch.randint(0, n_classes, (x.shape[0],), device=device)
      y0 = F.one_hot(y0_d, num_classes=n_classes).float()

      # First pass uses the hint; the second re-uses z1 and feeds y1 back as the hint.
      y1, z1 = f(x, y0)
      y2, _ = f.f2(z1, y1)

      # Only the two supervised terms are optimised here. The unsupervised
      # consistency objective (js_div) is used at test time, in ttt_one_instance.
      loss_supervised_1 = F.nll_loss(y1.log(), y)
      loss_supervised_2 = F.nll_loss(y2.log(), y)
      loss = loss_supervised_1 + loss_supervised_2

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      if batch_idx % log_interval == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tloss_supervised_1: {:.6f}, loss_supervised_2: {:.6f}'.format(
          epoch, batch_idx * len(data), len(train_loader.dataset),
          100. * batch_idx / len(train_loader), loss_supervised_1.item(), loss_supervised_2.item()))
        torch.save(f.state_dict(), './model.pth')
        torch.save(optimizer.state_dict(), './optimizer.pth')
      

def test(f, test_loader, n_classes=10, with_true_y=False):
  f.eval()
  test_loss, correct = 0, 0
  with torch.no_grad():
    for data, target in test_loader:
      x, y = data.to(device), target.to(device)
      if with_true_y:
        y0 = F.one_hot(y, num_classes=n_classes).float()
      else:
        y0 = torch.ones((len(y), n_classes), device=x.device).to(device).softmax(-1)
      y1, z1 = f(x, y0)
      test_loss += F.nll_loss(y1.log(), y, size_average=False).item()
      pred = y1.log().data.max(1, keepdim=True)[1]
      correct += pred.eq(y.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  print("True y given") if with_true_y else print("Max entropy given")
  print('\nTest: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))
  return 100 * correct / len(test_loader.dataset)

In [108]:
# Clean MNIST pipeline: tensor conversion plus normalisation with the commonly used MNIST mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)
# NOTE(review): batch_size_train / batch_size_test from the config cell are not used; both loaders use 64.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

f = Net().to(device)
optimizer = optim.Adam(f.parameters(), lr=learning_rate)

# Untrained baseline -> train -> evaluate with the uniform hint and with the true-label hint.
test(f, test_loader)
train(f, train_loader, optimizer, n_epochs)
test(f, test_loader)
test(f, test_loader, with_true_y=True)

Max entropy given

Test: Avg. loss: 2.3083, Accuracy: 1116/10000 (11%)

Max entropy given

Test: Avg. loss: 0.0419, Accuracy: 9869/10000 (99%)

True y given

Test: Avg. loss: 0.0078, Accuracy: 9978/10000 (100%)



tensor(99.7800)

## ITTT

In [109]:
from tqdm import tqdm 

def ttt_one_instance(x, f_ttt, f, optimizer, n_steps, y, n_classes=10):
  """Adapt ``f_ttt`` to one test batch ``x`` with the unsupervised consistency loss.

  ``f_ttt`` is reset to the weights of ``f``, and the optimizer state is cleared,
  so each batch is adapted independently of earlier ones. Each step:
    1. feeds a random one-hot label hint through ``f_ttt``;
    2. feeds the prediction back through ``f.f2``;
    3. minimises ``js_div`` between the two predictions.

  Args:
    x: input batch.
    f_ttt: model being adapted; its weights are overwritten in place.
    f: reference model. Its ``f2`` performs the second pass. Its weights are never
      stepped, but ``backward()`` does accumulate into their ``.grad`` buffers.
    optimizer: optimizer over ``f_ttt``'s parameters; its state is cleared here.
    n_steps: number of adaptation steps (0 means no adaptation).
    y: true labels. Unused, since adaptation is unsupervised; kept for call compatibility.
    n_classes: number of classes (width of the one-hot hint).

  Returns:
    ``(y1, y2)``: class probabilities of the first and second pass, computed without
    gradients AFTER the last update (previously they predated the final step).
  """
  f_ttt.load_state_dict(f.state_dict())  # reset f_ttt to f
  optimizer.state.clear()  # drop optimizer moments carried over from the previous batch
  f_ttt.train()

  def _random_hint():
    labels = torch.randint(0, n_classes, (x.shape[0],), device=x.device)
    return F.one_hot(labels, num_classes=n_classes).float()

  for _ in range(n_steps):
    y1, z1 = f_ttt(x, _random_hint())
    y2, _ = f.f2(z1, y1)
    loss = js_div(y1, y2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  # Predict with the fully adapted weights. This also defines y1/y2 when n_steps == 0.
  with torch.no_grad():
    y1, z1 = f_ttt(x, _random_hint())
    y2, _ = f.f2(z1, y1)
  return y1, y2

def ttt(f, test_loader, n_steps, lr):
  """Evaluate ``f`` on ``test_loader`` with per-batch test-time training (TTT).

  A deep copy ``f_ttt`` of ``f`` is adapted on every batch by ``ttt_one_instance``.
  ``f`` is only put in eval mode; the optimizer never steps its weights. The average
  NLL and the accuracy of both returned predictions are printed at the end.

  Args:
    f: trained model.
    test_loader: DataLoader of ``(data, target)`` batches; ``len(test_loader.dataset)``
      is used as the number of samples.
    n_steps: adaptation steps per batch.
    lr: Adam learning rate for the adaptation.

  Returns:
    0-dim tensor holding the FRACTION (0-1, not a percentage as ``test`` returns) of
    samples whose second-pass prediction ``y_hat_2`` is correct.
  """
  f_ttt = deepcopy(f)
  f.eval()
  dev = next(f.parameters()).device  # keep inputs on the model's device
  test_loss_1, correct_1 = 0.0, 0
  test_loss_2, correct_2 = 0.0, 0

  for data, target in tqdm(test_loader):
    x, y = data.to(dev), target.to(dev)
    # Fresh optimizer per batch so Adam moment estimates never leak across batches.
    optimizer = optim.Adam(f_ttt.parameters(), lr=lr)
    y_hat_1, y_hat_2 = ttt_one_instance(x, f_ttt, f, optimizer, n_steps, y)

    with torch.no_grad():  # metrics only; never extend the autograd graph here
      # ``reduction='sum'`` replaces the deprecated ``size_average=False``.
      test_loss_1 += F.nll_loss(y_hat_1.log(), y, reduction='sum').item()
      test_loss_2 += F.nll_loss(y_hat_2.log(), y, reduction='sum').item()
      correct_1 += y_hat_1.argmax(dim=1).eq(y).sum()
      correct_2 += y_hat_2.argmax(dim=1).eq(y).sum()

  n = len(test_loader.dataset)
  for name, loss_sum, n_correct in (('y_hat_1', test_loss_1, correct_1),
                                    ('y_hat_2', test_loss_2, correct_2)):
    print('TTT {}: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
      name, loss_sum / n, n_correct, n, 100. * n_correct / n))
  return correct_2 / n

In [111]:
class AddGaussianNoise(object):
    """Transform that returns ``tensor + N(0, 1) * std + mean``.

    The noise is drawn on the input's device, so the transform also works on CUDA
    tensors; the previous ``torch.randn(size)`` always produced CPU noise. For CPU
    inputs the values are identical to before. Float32 noise is combined with the
    input under normal type promotion, so integer inputs still yield a float result.
    """

    def __init__(self, std, mean=0.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        noise = torch.randn(tensor.shape, device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return '{0}(mean={1}, std={2})'.format(type(self).__name__, self.mean, self.std)

# Out-of-distribution test set: the MNIST test images with additive Gaussian noise
# applied after normalisation. It uses its own name so the clean ``transform`` from
# the data-loading cell is not silently rebound to the noisy pipeline.
ood_noise_std = 1.75
ood_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    AddGaussianNoise(ood_noise_std)])
ood_dataset = datasets.MNIST('data', train=False, download=True, transform=ood_transform)
ood_loader = torch.utils.data.DataLoader(ood_dataset, batch_size=64, shuffle=False)

In [112]:
# Baseline on the noisy OOD set (uniform label hint), before any test-time training.
test(f, ood_loader, with_true_y=False)

Max entropy given

Test: Avg. loss: 0.7814, Accuracy: 7380/10000 (74%)



tensor(73.8000)

In [None]:
# Test-time training on the noisy OOD set: 10 Adam steps (lr 1e-3) per batch.
ttt(f, ood_loader, lr=1e-3, n_steps=10)

  2%|███▎                                                                                                                                                                       | 3/157 [00:00<00:19,  8.02it/s]