In [2]:
import torch
from sklearn.linear_model import LogisticRegression
import math

# Data Generation for Experiment 1

In [3]:
N = 100
sigma_inv = 1
sigma_spu = 0.01

Y_1 = torch.cat((torch.ones(N//2), torch.zeros(N//2)))
Y_2 = torch.cat((torch.ones(N//2), torch.zeros(N//2)))
Y_3 = torch.cat((torch.ones(N//2), torch.zeros(N//2)))

Y_train = torch.cat((Y_1, Y_2))

Z_1_1 = torch.normal(2*Y_1-1, sigma_inv)
Z_1_2 = torch.clone(Z_1_1)
Z_1 = torch.cat((Z_1_1, Z_1_2))

Z_1_3 = torch.clone(Z_1_1)
Z_2_1 = torch.normal(2*Y_1-1, sigma_spu)
Z_2_2 = (Z_2_1-(2*Y_1-1))*10+(2*Y_1-1)
Z_2_3 = (Z_2_1-(2*Y_1-1))*1000+(2*Y_1-1)
Z_2 = torch.cat((Z_2_1, Z_2_2))

Z_train = torch.stack((Z_1, Z_2), dim=1)
Z_test = torch.stack((Z_1_3, Z_2_3), dim=1)
Y_test = Y_3

In [6]:
# ERM
clf = LogisticRegression(random_state=0, fit_intercept=False).fit(Z_train, Y_train)
clf.score(Z_test, Y_test)

0.63

In [7]:
# Oracle
clf = LogisticRegression(random_state=0, fit_intercept=False).fit(Z_1.reshape(-1,1), Y_train)
clf.score(Z_test[:,0].reshape(-1,1), Y_test)

0.82

distribution matching is not possible in the above examples.
Let's add a mixing function!

In [8]:
angle = torch.pi / 4

rot_matrix = torch.tensor([[math.cos(angle), -math.sin(angle)], [math.sin(angle), math.cos(angle)]])

X_train = Z_train @ rot_matrix
X_test = Z_test @ rot_matrix

In [9]:
clf = LogisticRegression(random_state=0, fit_intercept=False).fit(X_train, Y_train)

clf.score(X_test, Y_test)

0.63

Now let's move to counterfactual matching and distribution matching
Counterfactual matching

In [10]:
import torch.nn as nn
import torch.optim as optim

class LR(nn.Module):
    def __init__(self, D_in):
        super(LR, self).__init__()
        self.linear = nn.Linear(D_in, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.linear(x)
        x = self.sigmoid(x)
        return x.squeeze()


In [24]:
model = LR(2)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

for repeat in range(20):
    for n in range(1000): 
        y = model(X_train)
        loss = criterion(y, Y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    y_test = model(Z_test)
    print(float(((y_test>0.5)==Y_test).sum()/len(y_test)))

0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437
0.6600000262260437


In [23]:
for repeat in range(20):
    model = LR(2)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.02)

    lmda = 10000
    for n in range(1000): 
        y_1 = model(X_train[0:len(Z_train)//2])
        y_2 = model(X_train[len(Z_train)//2:])
        loss_1 = criterion(torch.cat((y_1,y_2)), Y_train)
        loss_2 = torch.norm(y_1-y_2) / len(y_1)
        # print(loss_1, loss_2)
        loss = loss_1 + lmda * loss_2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    y_test = model(X_test)
    y_train = model(X_train)
    print(float(((y_train>0.5)==Y_train).sum()/len(y_train)), float(((y_test>0.5)==Y_test).sum()/len(y_test)))
    

0.8299999833106995 0.8399999737739563
0.8399999737739563 0.8299999833106995
0.8299999833106995 0.8299999833106995
0.8299999833106995 0.8299999833106995
0.8299999833106995 0.8299999833106995
0.8399999737739563 0.8500000238418579
0.8199999928474426 0.800000011920929
0.8299999833106995 0.8399999737739563
0.8299999833106995 0.8100000023841858
0.8399999737739563 0.8299999833106995
0.8199999928474426 0.800000011920929
0.8299999833106995 0.8100000023841858
0.8399999737739563 0.8399999737739563
0.8199999928474426 0.8199999928474426
0.8399999737739563 0.8399999737739563
0.8199999928474426 0.8199999928474426
0.8299999833106995 0.8399999737739563
0.8399999737739563 0.8399999737739563
0.8399999737739563 0.8399999737739563
0.8299999833106995 0.8399999737739563


In [26]:
import torch.autograd as autograd

scale = torch.tensor(1.).requires_grad_()


def irm_penalty(loss_0, loss_1):
    grad_0 = autograd.grad(loss_0.mean(), [scale], create_graph=True)[0]
    grad_1 = autograd.grad(loss_1.mean(), [scale], create_graph=True)[0]
    result = torch.sum(grad_0 * grad_1)
    del grad_0, grad_1
    return result

for repeat in range(20):
    model = LR(2)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.05)
    penalty_anneal_iters = 100
    penalty_weight = 100

    for n in range(1000): 
        y_1 = model(X_train[0:len(X_train)//2])
        y_2 = model(X_train[len(X_train)//2:])
        loss_1 = criterion(y_1*scale, Y_train[0:len(Z_train)//2])
        loss_2 = criterion(y_2*scale, Y_train[len(Z_train)//2:])
        loss_3 = irm_penalty(loss_1, loss_2)
        # print(n, loss_1 + loss_2 ,loss_3)
        if n < penalty_anneal_iters:
            loss = loss_1 + loss_2
        else:
            loss = loss_1 + loss_2 + penalty_weight * loss_3
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    y_test = model(X_test)
    y_train = model(X_train)
    print(float(((y_train>0.5)==Y_train).sum()/len(y_train)), float(((y_test>0.5)==Y_test).sum()/len(y_test)))

0.9200000166893005 0.6600000262260437
0.8999999761581421 0.6600000262260437
0.8050000071525574 0.6499999761581421
0.8050000071525574 0.6499999761581421
0.8149999976158142 0.6499999761581421
0.8149999976158142 0.6499999761581421
0.8149999976158142 0.6499999761581421
0.8999999761581421 0.6600000262260437
0.8999999761581421 0.6600000262260437
0.8149999976158142 0.6499999761581421
0.9049999713897705 0.6600000262260437
0.8149999976158142 0.6499999761581421
0.9150000214576721 0.6600000262260437
0.8149999976158142 0.6499999761581421
0.8149999976158142 0.6499999761581421
0.9200000166893005 0.6600000262260437
0.8149999976158142 0.6499999761581421
0.8149999976158142 0.6499999761581421
0.9049999713897705 0.6600000262260437
0.8149999976158142 0.6499999761581421


In [16]:
y_test = model(Z_test)
print(((y_test>0.5)==Y_test).sum()/len(y_test))

tensor(0.5500)


# Data Generation for Experiment 2

In this experiment, we set the spurious variable $Z_2$ as the parents of $Z_1$. Furthermore, we need to gurantee that $Z_2$ and $Y$ are independent given $Z_1$. 

Z_2 -> Z_1 -> Y
Z_2 is marginal related to Y.

In [140]:
N = 100
sigma_inv = 0.01
sigma_spu = 1

Y_1 = torch.cat((torch.ones(N//2), torch.zeros(N//2)))
Y_2 = torch.cat((torch.ones(N//2), torch.zeros(N//2)))
Y_3 = torch.cat((torch.ones(N//2), torch.zeros(N//2)))


Z_2_1 = torch.normal(2*Y_1-1, sigma_spu)
Z_2_2 = (Z_2_1-(2*Y_1-1))*0.1+(2*Y_1-1)
Z_2_3 = (Z_2_1-(2*Y_1-1))*0.001+(2*Y_1-1)
Z_2 = torch.cat((Z_2_1, Z_2_2))


exogenous_noise = torch.normal(torch.zeros(N), sigma_inv)
Z_1_1 = Z_2_1 + exogenous_noise
Z_1_2 = Z_2_2 + exogenous_noise
Z_1_3 = Z_2_3 + exogenous_noise
Z_1 = torch.cat((Z_1_1, Z_1_2))

Y_train = ((Z_1>0).float()) 
Y_test = ((Z_1_3>0).float())


Z_train = torch.stack((Z_1, Z_2), dim=1)
Z_test = torch.stack((Z_1_3, Z_2_3), dim=1)


angle = torch.pi / 4

rot_matrix = torch.tensor([[math.cos(angle), -math.sin(angle)], [math.sin(angle), math.cos(angle)]])

X_train = Z_train @ rot_matrix
X_test = Z_test @ rot_matrix

In [141]:
clf = LogisticRegression(random_state=0, fit_intercept=False).fit(X_train, Y_train)

clf.score(X_test, Y_test)

1.0

In [142]:
# Oracle
clf = LogisticRegression(random_state=0, fit_intercept=False).fit(Z_1.reshape(-1,1), Y_train)
clf.score(Z_test[:,0].reshape(-1,1), Y_test)

1.0

In [143]:
model = LR(2)
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

for repeat in range(20):
    for n in range(500): 
        y = model(X_train)
        loss = criterion(y, Y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    y_test = model(X_test)
    print(float(((y_test>0.5)==Y_test).sum()/len(y_test)))

1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0


In [136]:
s = 0.
for repeat in range(20):
    model = LR(2)
    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.05)

    lmda = 100
    for n in range(500): 
        y_1 = model(X_train[0:len(X_train)//2])
        y_2 = model(X_train[len(X_train)//2:])
        loss_1 = criterion(torch.cat((y_1,y_2)), Y_train)
        loss_2 = torch.norm(y_1-y_2) / len(y_1)
        # print(loss_1, loss_2)
        loss = loss_1 + lmda * loss_2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    y_test = model(X_test)
    s += float(((y_test>0.5)==Y_test).sum()/len(y_test))
    print(float(((y_test>0.5)==Y_test).sum()/len(y_test)))
print("---------------")
print(s/20)

0.9800000190734863
0.9800000190734863
1.0
1.0
1.0
0.9900000095367432
1.0
1.0
1.0
1.0
1.0
1.0
1.0
0.9800000190734863
1.0
1.0
1.0
1.0
1.0
1.0
---------------
0.9965000033378602


In [131]:
import torch.autograd as autograd

scale = torch.tensor(1.).requires_grad_()


def irm_penalty(loss_0, loss_1):
    grad_0 = autograd.grad(loss_0.mean(), [scale], create_graph=True)[0]
    grad_1 = autograd.grad(loss_1.mean(), [scale], create_graph=True)[0]
    result = torch.sum(grad_0 * grad_1)
    del grad_0, grad_1
    return result

s = 0.
for repeat in range(20):
    model = LR(2)
    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.05)
    penalty_anneal_iters = 100
    penalty_weight = 1

    for n in range(500): 
        y_1 = model(X_train[0:len(Z_train)//2])
        y_2 = model(X_train[len(Z_train)//2:])
        loss_1 = criterion(y_1*scale, Y_train[0:len(Z_train)//2])
        loss_2 = criterion(y_2*scale, Y_train[len(Z_train)//2:])
        loss_3 = irm_penalty(loss_1, loss_2)
        # print(n, loss_1 + loss_2 ,loss_3)
        if n < penalty_anneal_iters:
            loss = loss_1 + loss_2
        else:
            loss = loss_1 + loss_2 + penalty_weight * loss_3
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    y_test = model(X_test)
    s += float(((y_test>0.5)==Y_test).sum()/len(y_test))
    print(float(((y_test>0.5)==Y_test).sum()/len(y_test)))
print("---------------")
print(s/20)

0.9599999785423279
0.9100000262260437
0.9100000262260437
0.9100000262260437
0.9300000071525574
0.9100000262260437
0.9100000262260437
0.9100000262260437
0.9100000262260437
0.9100000262260437
0.9200000166893005
0.9100000262260437
0.9100000262260437
0.9300000071525574
0.9300000071525574
0.9599999785423279
0.8999999761581421
0.8999999761581421
0.8999999761581421
0.9200000166893005
---------------
0.9175000101327896


# Data Generation for Experiment 3

In this experiment, we set the spurious variable $Z_2$ as the parents of $Z_1$. Furthermore, we need to gurantee that $Z_2$ and $Y$ are independent given $Z_1$. 

Z_1->Y->Z_2

In [242]:
N = 10000

sigma_inv = 1

exogenous_noise = torch.normal(torch.zeros(N), sigma_inv)
Z_1_1 = torch.zeros(N) + exogenous_noise
Z_1_2 = torch.ones(N) + exogenous_noise
Z_1_3 = 10*torch.ones(N) + exogenous_noise
Z_1 = torch.cat((Z_1_1, Z_1_2))

Z_2_1 = torch.normal(Z_1_1, sigma_inv)
Z_2_2 = torch.normal(Z_1_2, sigma_inv)
Z_2_3 = torch.normal(Z_1_3, sigma_inv)
Z_2 = torch.cat((Z_2_1, Z_2_2))


Y_train = (torch.sin(2*torch.pi*Z_1) > 0).float()
Y_test = (torch.sin(2*torch.pi*Z_1_3) > 0).float()

Z_train = torch.stack((Z_1, Z_2), dim=1)
Z_test = torch.stack((Z_1_3, Z_2_3), dim=1)

print(X_train.shape, Y_train.shape)
angle = torch.pi / 4

rot_matrix = torch.tensor([[math.cos(angle), -math.sin(angle)], [math.sin(angle), math.cos(angle)]])

X_train = Z_train @ rot_matrix
X_test = Z_test @ rot_matrix

torch.Size([2000, 2]) torch.Size([20000])


In [243]:
clf = LogisticRegression(random_state=0, fit_intercept=True).fit(X_train, Y_train)

clf.score(X_test, Y_test)

0.5109

In [251]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features=2,out_features=4),
            nn.ReLU(),
            nn.Linear(in_features=4,out_features=8),
            nn.ReLU(),
            nn.Linear(in_features=8,out_features=16),
            nn.ReLU(),
            nn.Linear(in_features=16,out_features=4),
            nn.ReLU(),
            nn.Linear(in_features=4,out_features=1),
            nn.Sigmoid()
        )

    def forward(self,x):
        return self.model(x).squeeze()

In [252]:
model = CNN()
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

for repeat in range(20):
    for n in range(500): 
        y = model(X_train)
        loss = criterion(y, Y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    y_test = model(X_test)
    print(float(((y_test>0.5)==Y_test).sum()/len(y_test)))

0.48910000920295715
0.48910000920295715
0.48910000920295715


KeyboardInterrupt: 

In [256]:
s = 0.
c = 0.
for repeat in range(1):
    model = CNN()
    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    lmda = 0
    for n in range(1000): 
        y_1 = model(X_train[0:len(X_train)//2])
        y_2 = model(X_train[len(X_train)//2:])
        loss_1 = criterion(torch.cat((y_1,y_2)), Y_train)
        loss_2 = torch.norm(y_1-y_2) / len(y_1)
        print(n, loss_1, loss_2)
        loss = loss_1 + lmda * loss_2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    y_train = model(X_train)
    c += float(((y_train>0.5)==Y_train).sum()/len(y_train))
    print(float(((y_train>0.5)==Y_train).sum()/len(y_train)))
    y_test = model(X_test)
    s += float(((y_test>0.5)==Y_test).sum()/len(y_test))
    print(float(((y_test>0.5)==Y_test).sum()/len(y_test)))
print("---------------")
print(s/1)

0 tensor(0.7246, grad_fn=<BinaryCrossEntropyBackward0>) tensor(1.6035e-05, grad_fn=<DivBackward0>)
1 tensor(0.7218, grad_fn=<BinaryCrossEntropyBackward0>) tensor(1.9683e-05, grad_fn=<DivBackward0>)
2 tensor(0.7192, grad_fn=<BinaryCrossEntropyBackward0>) tensor(2.4394e-05, grad_fn=<DivBackward0>)
3 tensor(0.7169, grad_fn=<BinaryCrossEntropyBackward0>) tensor(2.9532e-05, grad_fn=<DivBackward0>)
4 tensor(0.7147, grad_fn=<BinaryCrossEntropyBackward0>) tensor(3.4793e-05, grad_fn=<DivBackward0>)
5 tensor(0.7128, grad_fn=<BinaryCrossEntropyBackward0>) tensor(4.0019e-05, grad_fn=<DivBackward0>)
6 tensor(0.7110, grad_fn=<BinaryCrossEntropyBackward0>) tensor(4.5050e-05, grad_fn=<DivBackward0>)
7 tensor(0.7094, grad_fn=<BinaryCrossEntropyBackward0>) tensor(4.9701e-05, grad_fn=<DivBackward0>)
8 tensor(0.7079, grad_fn=<BinaryCrossEntropyBackward0>) tensor(5.4143e-05, grad_fn=<DivBackward0>)
9 tensor(0.7066, grad_fn=<BinaryCrossEntropyBackward0>) tensor(5.8380e-05, grad_fn=<DivBackward0>)
10 tensor(

85 tensor(0.6932, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
86 tensor(0.6932, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
87 tensor(0.6932, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
88 tensor(0.6932, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
89 tensor(0.6932, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
90 tensor(0.6932, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
91 tensor(0.6932, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
92 tensor(0.6932, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
93 tensor(0.6932, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
94 tensor(0.6932, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
95 tensor(0.6932, grad_fn=<BinaryCrossEn

171 tensor(0.6931, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
172 tensor(0.6931, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
173 tensor(0.6931, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
174 tensor(0.6931, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
175 tensor(0.6931, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
176 tensor(0.6931, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
177 tensor(0.6931, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
178 tensor(0.6931, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
179 tensor(0.6931, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
180 tensor(0.6931, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
181 tensor(0.6931, grad_fn=<Bi

258 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
259 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
260 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(0.0001, grad_fn=<DivBackward0>)
261 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.9941e-05, grad_fn=<DivBackward0>)
262 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.9828e-05, grad_fn=<DivBackward0>)
263 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.9710e-05, grad_fn=<DivBackward0>)
264 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.9597e-05, grad_fn=<DivBackward0>)
265 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.9482e-05, grad_fn=<DivBackward0>)
266 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.9368e-05, grad_fn=<DivBackward0>)
267 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.9255e-05, grad_fn=<DivBackward0>)
26

345 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.1557e-05, grad_fn=<DivBackward0>)
346 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.1471e-05, grad_fn=<DivBackward0>)
347 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.1384e-05, grad_fn=<DivBackward0>)
348 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.1302e-05, grad_fn=<DivBackward0>)
349 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.1215e-05, grad_fn=<DivBackward0>)
350 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.1131e-05, grad_fn=<DivBackward0>)
351 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.1048e-05, grad_fn=<DivBackward0>)
352 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.0963e-05, grad_fn=<DivBackward0>)
353 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.0880e-05, grad_fn=<DivBackward0>)
354 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(9.0796e-05, grad_fn=<DivBa

432 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(8.4806e-05, grad_fn=<DivBackward0>)
433 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(8.4735e-05, grad_fn=<DivBackward0>)
434 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(8.4664e-05, grad_fn=<DivBackward0>)
435 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(8.4593e-05, grad_fn=<DivBackward0>)
436 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(8.4522e-05, grad_fn=<DivBackward0>)
437 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(8.4452e-05, grad_fn=<DivBackward0>)
438 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(8.4382e-05, grad_fn=<DivBackward0>)
439 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(8.4312e-05, grad_fn=<DivBackward0>)
440 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(8.4241e-05, grad_fn=<DivBackward0>)
441 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(8.4171e-05, grad_fn=<DivBa

519 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.9071e-05, grad_fn=<DivBackward0>)
520 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.9008e-05, grad_fn=<DivBackward0>)
521 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.8949e-05, grad_fn=<DivBackward0>)
522 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.8887e-05, grad_fn=<DivBackward0>)
523 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.8829e-05, grad_fn=<DivBackward0>)
524 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.8768e-05, grad_fn=<DivBackward0>)
525 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.8710e-05, grad_fn=<DivBackward0>)
526 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.8648e-05, grad_fn=<DivBackward0>)
527 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.8589e-05, grad_fn=<DivBackward0>)
528 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.8528e-05, grad_fn=<DivBa

605 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.4191e-05, grad_fn=<DivBackward0>)
606 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.4139e-05, grad_fn=<DivBackward0>)
607 tensor(0.6930, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.4085e-05, grad_fn=<DivBackward0>)
608 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.4033e-05, grad_fn=<DivBackward0>)
609 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.3979e-05, grad_fn=<DivBackward0>)
610 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.3926e-05, grad_fn=<DivBackward0>)
611 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.3873e-05, grad_fn=<DivBackward0>)
612 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.3821e-05, grad_fn=<DivBackward0>)
613 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.3768e-05, grad_fn=<DivBackward0>)
614 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(7.3716e-05, grad_fn=<DivBa

693 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.9866e-05, grad_fn=<DivBackward0>)
694 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.9820e-05, grad_fn=<DivBackward0>)
695 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.9774e-05, grad_fn=<DivBackward0>)
696 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.9729e-05, grad_fn=<DivBackward0>)
697 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.9683e-05, grad_fn=<DivBackward0>)
698 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.9637e-05, grad_fn=<DivBackward0>)
699 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.9591e-05, grad_fn=<DivBackward0>)
700 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.9545e-05, grad_fn=<DivBackward0>)
701 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.9499e-05, grad_fn=<DivBackward0>)
702 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.9454e-05, grad_fn=<DivBa

779 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.6159e-05, grad_fn=<DivBackward0>)
780 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.6119e-05, grad_fn=<DivBackward0>)
781 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.6079e-05, grad_fn=<DivBackward0>)
782 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.6039e-05, grad_fn=<DivBackward0>)
783 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.5999e-05, grad_fn=<DivBackward0>)
784 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.5959e-05, grad_fn=<DivBackward0>)
785 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.5920e-05, grad_fn=<DivBackward0>)
786 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.5880e-05, grad_fn=<DivBackward0>)
787 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.5841e-05, grad_fn=<DivBackward0>)
788 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.5800e-05, grad_fn=<DivBa

866 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.2913e-05, grad_fn=<DivBackward0>)
867 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.2879e-05, grad_fn=<DivBackward0>)
868 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.2843e-05, grad_fn=<DivBackward0>)
869 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.2807e-05, grad_fn=<DivBackward0>)
870 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.2772e-05, grad_fn=<DivBackward0>)
871 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.2737e-05, grad_fn=<DivBackward0>)
872 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.2702e-05, grad_fn=<DivBackward0>)
873 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.2668e-05, grad_fn=<DivBackward0>)
874 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.2633e-05, grad_fn=<DivBackward0>)
875 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.2599e-05, grad_fn=<DivBa

954 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(6.0016e-05, grad_fn=<DivBackward0>)
955 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(5.9985e-05, grad_fn=<DivBackward0>)
956 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(5.9955e-05, grad_fn=<DivBackward0>)
957 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(5.9924e-05, grad_fn=<DivBackward0>)
958 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(5.9893e-05, grad_fn=<DivBackward0>)
959 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(5.9863e-05, grad_fn=<DivBackward0>)
960 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(5.9834e-05, grad_fn=<DivBackward0>)
961 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(5.9803e-05, grad_fn=<DivBackward0>)
962 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(5.9773e-05, grad_fn=<DivBackward0>)
963 tensor(0.6929, grad_fn=<BinaryCrossEntropyBackward0>) tensor(5.9742e-05, grad_fn=<DivBa

In [213]:
import torch.autograd as autograd

scale = torch.tensor(1.).requires_grad_()


def irm_penalty(loss_0, loss_1):
    grad_0 = autograd.grad(loss_0.mean(), [scale], create_graph=True)[0]
    grad_1 = autograd.grad(loss_1.mean(), [scale], create_graph=True)[0]
    result = torch.sum(grad_0 * grad_1)
    del grad_0, grad_1
    return result

s = 0.
for repeat in range(20):
    model = CNN()
    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.05)
    penalty_anneal_iters = 100
    penalty_weight = 1

    for n in range(500): 
        y_1 = model(X_train[0:len(Z_train)//2])
        y_2 = model(X_train[len(Z_train)//2:])
        loss_1 = criterion(y_1*scale, Y_train[0:len(Z_train)//2])
        loss_2 = criterion(y_2*scale, Y_train[len(Z_train)//2:])
        loss_3 = irm_penalty(loss_1, loss_2)
        # print(n, loss_1 + loss_2 ,loss_3)
        if n < penalty_anneal_iters:
            loss = loss_1 + loss_2
        else:
            loss = loss_1 + loss_2 + penalty_weight * loss_3
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    y_test = model(X_test)
    s += float(((y_test>0.5)==Y_test).sum()/len(y_test))
    print(float(((y_test>0.5)==Y_test).sum()/len(y_test)))
print("---------------")
print(s/20)

0.5
0.5
0.5400000214576721
0.5
0.5
0.5
0.5
0.5
0.5
0.5
0.5
0.5
0.5
0.5
0.49000000953674316
0.5
0.5
0.5
0.5
0.5
---------------
0.5015000015497207
