In [None]:
import wget
import os
# Fetch the two helper modules into the working directory, skipping any file
# that is already present so re-running the cell does not download again.
# NOTE(review): both URLs point at branch heads ("master" / "main"), not pinned
# commits, so the downloaded code can change between runs.
for _target, _url in (
    ("colour_mnist.py", "https://raw.githubusercontent.com/clovaai/rebias/master/datasets/colour_mnist.py"),
    ("flac.py", "https://raw.githubusercontent.com/gsarridis/FLAC/main/flac.py"),
):
    if os.path.exists(_target):
        continue
    wget.download(_url, _target)

In [None]:
from colour_mnist import get_biased_mnist_dataloader
# Both loaders read from the same root with the same batch size; they differ
# only in the split and in data_label_correlation (0.99 for train, 0.1 for test).
_shared_loader_args = {'root': './data', 'batch_size': 128}
trainloader = get_biased_mnist_dataloader(data_label_correlation=0.99, train=True, **_shared_loader_args)
testloader = get_biased_mnist_dataloader(data_label_correlation=0.1, train=False, **_shared_loader_args)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Two-convolution classifier that returns logits and the features behind them.

    Reference: https://github.com/gsarridis/FLAC/blob/main/models/simple_conv.py
    """

    def __init__(self):
        super().__init__()
        hidden_channels = 6
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=hidden_channels, kernel_size=5)
        # A single pooling layer is shared by both convolution stages.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=hidden_channels, out_channels=16, kernel_size=5)
        # The head expects 256 flattened features per sample and emits 10 logits.
        self.fc = nn.Linear(in_features=256, out_features=10)

    def forward(self, x):
        """Return ``(logits, features)`` for a batch ``x``.

        ``features`` is the flattened convolutional output after L2
        normalisation along dim 1; ``logits`` is the linear head applied to
        those normalised features.
        """
        # conv -> ReLU -> max-pool, applied once per convolution stage.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten everything except the batch dimension, then normalise per sample.
        features = F.normalize(torch.flatten(x, 1), dim=1)
        logits = self.fc(features)
        return logits, features

In [None]:
import torch.optim as optim
import flac

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Classifier trained to predict element 2 of each batch (the "sensitive"
# target) from the images; it is reused later through the name `color_net`.
color_net = Net().to(device)
lossfunc = nn.CrossEntropyLoss()
optimizer = optim.Adam(color_net.parameters(), lr=0.001)

for epoch in range(1):
    epoch_loss = 0
    # `batches` counts the batches seen so far, starting at 1, so the running
    # mean printed below is well defined from the first iteration.
    for batches, data in enumerate(trainloader, start=1):
        inputs = data[0].to(device)
        sensitive = data[2].to(device)
        optimizer.zero_grad()
        outputs, features = color_net(inputs)
        loss = lossfunc(outputs, sensitive)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # '\r' with end='' rewrites the same output line on every batch.
        print(f'\r[{epoch + 1}] sensitive attribute classifier loss: {epoch_loss/batches:.3f}', end='')
print("\nFinished training sensitive attribute classifier")

In [None]:
def train(net, color_net=None, color_reg=100, epochs=5, loader=None):
    """Train ``net`` with cross-entropy, optionally adding a weighted FLAC term.

    Args:
        net: model whose forward pass returns ``(outputs, features)``. It is
            switched to training mode and updated in place.
        color_net: optional second model with the same forward contract. When
            given, it is switched to eval mode, run without gradients, and its
            features are passed to ``flac.flac_loss`` together with ``net``'s.
        color_reg: multiplier applied to the FLAC term before it is added to
            the cross-entropy loss.
        epochs: number of passes over ``loader``.
        loader: iterable of batches where ``batch[0]`` holds the inputs and
            ``batch[1]`` the labels. Defaults to the notebook-level
            ``trainloader``.

    Returns:
        None. Progress is printed as a running mean loss per epoch.
    """
    if loader is None:
        loader = trainloader
    lossfunc = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    # Move batches to wherever the model's parameters live, so the function
    # does not depend on a notebook-level `device` variable.
    run_device = next(net.parameters()).device
    # Explicitly enter training mode; a model left in eval mode by earlier
    # evaluation code would otherwise be optimised in the wrong mode.
    net.train()
    if color_net is not None:
        color_net.eval()
    for epoch in range(epochs):
        epoch_loss = 0
        batches = 0
        for data in loader:
            optimizer.zero_grad()
            inputs, labels = data[0].to(run_device), data[1].to(run_device)
            outputs, features = net(inputs)
            loss = lossfunc(outputs, labels)
            if color_net is not None:
                # The auxiliary model is only a feature provider here, so no
                # gradient graph is built for it.
                with torch.no_grad():
                    color_outputs, color_features = color_net(inputs)
                # reshape(-1) keeps the labels one-dimensional even for a batch
                # of a single sample; squeeze() would yield a 0-d tensor there.
                flat_labels = labels.reshape(-1)
                loss = loss + color_reg*flac.flac_loss(color_features, features, flat_labels, device=run_device)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            batches += 1
            print(f'\r[{epoch + 1}] loss: {epoch_loss/batches:.3f}', end='')
    print("\nFinished training")

In [None]:
# Baseline model: trained with train()'s default arguments (no color_net),
# then its weights are written to ./data/model.pth.
net = Net().to(device)
train(net)
torch.save(net.state_dict(), './data/model.pth')

In [None]:
# Second model: same architecture, but trained with color_net supplied so that
# train() adds the FLAC term; weights are written to ./data/fair_model.pth.
fair_net = Net().to(device)
train(fair_net, color_net=color_net)
torch.save(fair_net.state_dict(), './data/fair_model.pth')