<a href="https://colab.research.google.com/github/dimidagd/gists/blob/main/Check_for_grad_mix_across_batchipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Simple model for demonstration
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv = nn.Conv2d(1, 10, kernel_size=5)
        self.fc = nn.Linear(10 , 1)
        self.bn = nn.InstanceNorm2d(10)
        self.ln = nn.LayerNorm(10)
        self.freeze_layer(self.fc)
    def forward(self, x):
        x = F.relu(self.conv(x))

        x = self.bn(x)
        x = x.view(x.size(0),-1,x.size(1))
        x = self.ln(x)

        x = self.fc(x)
        return x
    @staticmethod
    def freeze_layer(layer):
        for param in layer.parameters():
          param.requires_grad = False


In [8]:
model = SimpleModel()
model.zero_grad()
import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [9]:
model.fc.weight

Parameter containing:
tensor([[-0.1893, -0.1726, -0.1329,  0.2299, -0.1490, -0.2646,  0.2749,  0.0196,
          0.0782, -0.1371]])

In [13]:
n_batch = 5
for i in range(n_batch):
  X = torch.rand(n_batch,1,50,50,requires_grad=True)
  loss = model(X)
  sample_loss = loss[i]
  sample_loss.sum().backward()
  with torch.no_grad():
    grads = [sample_grad.abs().sum().item() for sample_grad in X.grad]
    print("gradients are",grads)
    grads.pop(i)
    #assert not any(grads)
    old_weights = model.fc.weight.clone().detach()
    optimizer.step()
    new_weights = model.fc.weight.clone().detach()
    print("Difference in layer weights",(old_weights-new_weights).sum())
    model.zero_grad()

gradinets are [14076.8310546875, 0.0, 0.0, 0.0, 0.0]
Difference in layer weights tensor(0.)
gradinets are [0.0, 19610.6640625, 0.0, 0.0, 0.0]
Difference in layer weights tensor(0.)
gradinets are [0.0, 0.0, 23202.52734375, 0.0, 0.0]
Difference in layer weights tensor(0.)
gradinets are [0.0, 0.0, 0.0, 22779.171875, 0.0]
Difference in layer weights tensor(0.)
gradinets are [0.0, 0.0, 0.0, 0.0, 19541.716796875]
Difference in layer weights tensor(0.)
