"Expected all tensors to be on the same device" When Creating LIF Neurons #316

Closed
kt-13 opened this issue Apr 19, 2024 · 8 comments · Fixed by #317
Closed

"Expected all tensors to be on the same device" When Creating LIF Neurons #316

kt-13 opened this issue Apr 19, 2024 · 8 comments · Fixed by #317

Comments

kt-13 commented Apr 19, 2024

  • snntorch version: 0.9.0
  • Python version:
  • Operating System: Colab Linux

Description

I get an error telling me that all tensors must be on the same device when I try to create a new model on a GPU. The code I am using is below. It seems to be a similar issue to the one reported in #225. If I manually set the device for each Leaky object, as in the commented-out lines, the error goes away.
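
(For reference, a minimal sketch, not from the original report, of what appears to trigger the same error on snntorch 0.9.0, assuming a CUDA runtime and the default Leaky arguments:)

import torch
import snntorch as snn

# Move the neuron to the GPU and call it with a GPU tensor, without passing `mem` explicitly.
# On 0.9.0 the internally held membrane state appears to stay on the CPU, so the
# threshold subtraction inside the neuron mixes cuda and cpu tensors.
lif = snn.Leaky(beta=0.9).to("cuda")
x = torch.rand(8, 16, device="cuda")
lif(x)  # RuntimeError: Expected all tensors to be on the same device ...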

What I Did

import time

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

import snntorch as snn
import snntorch.functional as SF


class VGG7_SNN(nn.Module):
    def __init__(self, beta, num_classes=10):
        super(VGG7_SNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(3)
        self.leaky1 = snn.Leaky(beta=beta)
        #self.leaky1.mem = self.leaky1.mem.to(torch.device("cuda"))

        self.conv2 = nn.Conv2d(3, 64, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(64)
        self.leaky2 = snn.Leaky(beta=beta)
        #self.leaky2.mem = self.leaky2.mem.to(torch.device("cuda"))
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.bn3 = nn.BatchNorm2d(128)
        self.leaky3 = snn.Leaky(beta=beta)
        #self.leaky3.mem = self.leaky3.mem.to(torch.device("cuda"))

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3)
        self.bn4 = nn.BatchNorm2d(128)
        self.leaky4 = snn.Leaky(beta=beta)
        #self.leaky4.mem = self.leaky4.mem.to(torch.device("cuda"))

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3)
        self.bn5 = nn.BatchNorm2d(256)
        self.leaky5 = snn.Leaky(beta=beta)
        #self.leaky5.mem = self.leaky5.mem.to(torch.device("cuda"))

        self.conv6 = nn.Conv2d(256, 256, kernel_size=3)
        self.bn6 = nn.BatchNorm2d(256)
        self.leaky6 = snn.Leaky(beta=beta)
        #self.leaky6.mem = self.leaky6.mem.to(torch.device("cuda"))

        self.conv7 = nn.Conv2d(256, 256, kernel_size=3)
        self.bn7 = nn.BatchNorm2d(256)
        self.leaky7 = snn.Leaky(beta=beta)
        #self.leaky7.mem = self.leaky7.mem.to(torch.device("cuda"))

        self.lin1 = nn.Linear(256*7*7, 4096)
        self.bn1d1 = nn.BatchNorm1d(4096)
        self.leaky14 = snn.Leaky(beta=beta)
        #self.leaky14.mem = self.leaky14.mem.to(torch.device("cuda"))
        self.do = nn.Dropout(p = 0.2)
        self.lin2 = nn.Linear(4096, 2048)
        self.bn1d2 = nn.BatchNorm1d(2048)  # matches the output width of lin2
        self.leaky15 = snn.Leaky(beta=beta)
        #self.leaky15.mem = self.leaky15.mem.to(torch.device("cuda"))
        self.do2 = nn.Dropout(p = 0.2)
        self.lin3 = nn.Linear(2048, num_classes)

        self.leaky16 = snn.Leaky(beta=beta)
        #self.leaky16.mem = self.leaky16.mem.to(torch.device("cuda"))

    def forward(self, x, time_steps, ratio, epoch):
       
        self.leaky1.reset_mem()
        self.leaky2.reset_mem()
        self.leaky3.reset_mem()
        self.leaky4.reset_mem()
        self.leaky5.reset_mem()
        self.leaky6.reset_mem()
        self.leaky7.reset_mem()

        self.leaky14.reset_mem()
        self.leaky15.reset_mem()
        self.leaky16.reset_mem()

        spk_recording = []
        xOrig = x

        for t in range(time_steps):
            start1 = time.time()
            b1 = self.bn1(self.conv1(xOrig))
            l1 = self.leaky1(b1)[0]#, m1
            

            b2 = self.bn2(self.conv2(l1))
            l2 = self.leaky2(b2)[0]#, m2
            
            b3 = self.bn3(self.conv3(l2))
            l3 = self.leaky3(b3)[0] #, m3

            b4 = self.bn4(self.conv4(l3))
            l4 = self.leaky4(b4)[0]#, m4
            

            b5 = self.bn5(self.conv5(l4))
            l5 = self.leaky5(b5)[0]#, m5

            b6 = self.bn6(self.conv6(l5))
            l6 = self.leaky6(b6)[0]#, m6

            b7 = self.bn7(self.conv7(l6))
            p1 = self.pool1(b7)
            l7 = self.leaky7(p1)[0]#, m7

            f1 = torch.flatten(l7, 1)
            
            fc1 = self.lin1(f1)

            l14 = self.leaky14(fc1)[0]#, m14

            fc2 = self.lin2(l14)

            l15 = self.leaky15(fc2)[0]#, m15

            fc3 = self.lin3(l15)
            l16 = self.leaky16(fc3)[0]#, m16
            spk_recording.append(l16)
        return torch.stack(spk_recording)
batch_size = 64
num_epochs = 10
time_steps = 5
beta = 0.75
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(p=0.3),
    torchvision.transforms.RandomVerticalFlip(p=0.3),
    torchvision.transforms.ToTensor(),
])
train = torchvision.datasets.FashionMNIST(root='/content', transform=transform, download=True)

#train_ds = torch.utils.data.Subset(train, torch.arange(0, 10000))
test = torchvision.datasets.FashionMNIST(root='/content', train=False, transform=torchvision.transforms.ToTensor(), download=True)
#test_ds = torch.utils.data.Subset(train, torch.arange(1000, 2000))
data_loader = DataLoader(train, batch_size=batch_size, shuffle=True)#, num_workers = 8
test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)

spk_rec_final = []
loss_hist = []
acc_hist = []

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = VGG7_SNN(beta, 10).to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=2.47e-4, momentum=0.9) #lr=2.5e-4 if no weight decay
loss_fn =  SF.ce_rate_loss()

for epoch in range(num_epochs):

  print(f"Starting epoch number {epoch}")
  counter = 0
  torch.cuda.empty_cache()
 
  for bitmap, target in iter(data_loader):

      bitmap = bitmap.to(device)
      target = target.to(device)
      model.train()


      spk_rec = model(bitmap, time_steps, ratio, epoch)
      spk_rec_final.append(spk_rec)

      loss_val = loss_fn(spk_rec, target)
      
      optimizer.zero_grad()
      loss_val.backward()

      optimizer.step()
      loss_hist.append(loss_val.item())

  if epoch == num_epochs -1:
      print('calcing accuracy')
      with torch.no_grad():
        model.eval()
        acc_train = batch_accuracy(data_loader, model, time_steps, device, ratio)
        acc_test = batch_accuracy(test_loader, model, time_steps, device, ratio)
        print(f"Iteration {epoch}, Train Acc: {acc_train * 100:.2f}%\n")
        print(f"Iteration {epoch}, Test Acc: {acc_test * 100:.2f}%\n")
        acc_hist.append(acc_test.item())
        break

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 54
     50 model.train()
     51 #print(torch.unsqueeze(bitmap, dim=1).shape)
---> 54 spk_rec = model(bitmap, time_steps, ratio, epoch)
     55 #print('time to get spks', time.time() - start)
     56 spk_rec_final.append(spk_rec)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[4], line 152, in VGG16_SNN.forward(self, x, time_steps, ratio, epoch)
    149 b1 = self.bn1(self.conv1(xOrig))
    150 #print('time to convolve and normalize', time.time() -  start1)
    151 #print(len(torch.where(x == 1)[0]))
--> 152 l1 = self.leaky1(b1)[0]#, m1
    153 #print(l1.shape)
    154
    155 
    156 #print('time to run through first block', time.time() -  start1)
    157 #start = time.time()
    158 
    159 #print(len(torch.where(x == 1)[0]))

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.10/dist-packages/snntorch/_neurons/leaky.py:208, in Leaky.forward(self, input_, mem)
    205 if not self.mem.shape == input_.shape:
    206     self.mem = torch.zeros_like(input_, device=self.mem.device)
--> 208 self.reset = self.mem_reset(self.mem)
    209 self.mem = self.state_function(input_)
    211 if self.state_quant:

File /usr/local/lib/python3.10/dist-packages/snntorch/_neurons/neurons.py:105, in SpikingNeuron.mem_reset(self, mem)
    102 def mem_reset(self, mem):
    103     """Generates detached reset signal if mem > threshold.
    104     Returns reset."""
--> 105     mem_shift = mem - self.threshold
    106     reset = self.spike_grad(mem_shift).clone().detach()
    108     return reset

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
xiziqiao commented Apr 19, 2024

I got the same issue here! I think it is a library bug.
My solution was to downgrade to 0.8.0.

@SSBakh07

Apparently this issue has come up before (#225) and a workaround was described there, but I was able to temporarily fix it by downgrading to version 0.8.1:

pip install snntorch==0.8.1

gekkom (Contributor) commented Apr 21, 2024

I will take a look at this; in the meantime, you can also work around it with:

torch.set_default_device("cuda")
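
(For example, a hedged sketch reusing the VGG7_SNN class from the report; placing the call before model construction is an assumption about where the neurons' state tensors are created:)

import torch

# Assumes a CUDA runtime; call this before constructing the model so that the
# neurons' internal state tensors are allocated on the GPU by default.
torch.set_default_device("cuda")

model = VGG7_SNN(beta, 10).to("cuda")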

morenzoe commented Apr 23, 2024

I still get this error when running the training loop without population coding in the Advanced Tutorials: Population Coding notebook. Setting the default device to cuda did not work, but downgrading to 0.8.1 did the job. I guess it's because of the deprecation of the snntorch.backprop module.

jeshraghian (Owner) commented Apr 23, 2024

Have you tried installing snntorch from source rather than from pip?
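
(For reference, one common way to do that in Colab, assuming the default branch of the GitHub repository, is:)

!pip install git+https://github.com/jeshraghian/snntorch.git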

@morenzoe

I am trying to do that in Colab now. However, another error comes up, ModuleNotFoundError: No module named 'nir', even though the module is there when I check with !pip show. Does the setup.py in snnTorch only install the module locally in the snnTorch folder path? Sorry for asking off topic; any help would be much appreciated!

@jeshraghian (Owner)

Ah, I ran into the same error, but it was fixed when I restarted my runtime... in any case, I'll update the PyPI release today or tomorrow. That'll hopefully fix everything.

@morenzoe

I am finally able to run both of the tutorials in Colab by installing and importing nir and nirtorch first, before installing snntorch from source. Nevertheless, updating the PyPI release will be a great help. Thank you!
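
(A sketch of that install order in a Colab cell, assuming nir and nirtorch are the PyPI package names and the default snntorch branch:)

!pip install nir nirtorch
import nir
import nirtorch

!pip install git+https://github.com/jeshraghian/snntorch.git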
