#!/usr/bin/env python
# coding: utf-8
"""Smoke test: nn.DataParallel with a real nn.Conv2d model and a complex
cplxmodule CplxConv2d model, fed real and complex inputs respectively."""
import os

# Pin the visible GPUs before importing torch so device enumeration is stable.
gpu_ids = "5,6"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids
# After masking, the visible devices are re-indexed from zero.
device_ids = list(range(len(gpu_ids.split(","))))

import torch
from torch import nn

from cplxmodule import cplx
import cplxmodule.nn as cplxnn


class cplxtest_net(nn.Module):
    """Single complex-valued convolution: 3 -> 20 channels, 5x5 kernel, stride 1."""

    def __init__(self):
        super().__init__()
        self.conv1 = cplxnn.CplxConv2d(3, 20, 5, 1)

    def forward(self, z):
        return self.conv1(z)


class test_net(nn.Module):
    """Real-valued counterpart with the same convolution geometry."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 20, 5, 1)

    def forward(self, x):
        return self.conv1(x)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
is_parallel = len(device_ids) > 1

model_real = test_net().to(device)
model_complex = cplxtest_net().to(device)

# With more than one visible device, wrap both models so batches are
# scattered across the GPUs listed in device_ids.
if is_parallel:
    model_real = nn.DataParallel(model_real, device_ids=device_ids)
    model_complex = nn.DataParallel(model_complex, device_ids=device_ids)

real_data = torch.randn(1, 3, 224, 224).to(device)
# A Cplx value bundles a real and an imaginary tensor of the same shape.
complex_data = cplx.Cplx(torch.randn(1, 3, 224, 224),
                         torch.randn(1, 3, 224, 224)).to(device)

# A 5x5 valid convolution maps 224x224 inputs to (224 - 5 + 1) = 220x220 outputs.
real_output = model_real(real_data)
assert real_output.shape == torch.Size([1, 20, 220, 220])

complex_output = model_complex(complex_data)
assert complex_output.shape == torch.Size([1, 20, 220, 220])
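
# What a complex convolution computes, in principle, is the standard identity
#   (W_r + i*W_i) * (x_r + i*x_i) = (W_r*x_r - W_i*x_i) + i*(W_r*x_i + W_i*x_r).
# Below is a minimal sketch of that identity using two plain real nn.Conv2d
# layers. It illustrates the arithmetic only and assumes nothing about
# cplxmodule's internal parameter layout; it reuses torch/nn imported above.
conv_r = nn.Conv2d(3, 20, 5, 1, bias=False)  # real part of the kernel
conv_i = nn.Conv2d(3, 20, 5, 1, bias=False)  # imaginary part of the kernel

x_r = torch.randn(1, 3, 224, 224)
x_i = torch.randn(1, 3, 224, 224)

out_r = conv_r(x_r) - conv_i(x_i)  # real part of the output
out_i = conv_r(x_i) + conv_i(x_r)  # imaginary part of the output
assert out_r.shape == out_i.shape == torch.Size([1, 20, 220, 220])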