<div style="line-height:0.5">
<h1 style="color:#BF66F2 ">  Tensors in PyTorch 3 </h1>
<h4> Indexing, Unsqueezing, and Splitting. </h4> 
<div style="margin-top: 5px;">
<span style="display: inline-block;">
    <h3 style="color: lightblue; display: inline;">Keywords:</h3>
    warnings.catch_warnings + torch.any() + torch.all() + torch.arange() + views + torch.cat()
</span>
</div>
</div>

In [2]:
import torch
import warnings
import numpy as np

In [3]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    device = "cuda" if torch.cuda.is_available() else "cpu"  
    device

In [3]:
# Initialize a 2x3 Tensor (2 rows, 3 columns)
my_tensor = torch.tensor(
    [[1, 2, 3], [4, 5, 6]], dtype=torch.float32, device=device, requires_grad=True)

In [4]:
print(f"Information about tensor: {my_tensor}")  
print("Type of Tensor {my_tensor.dtype}")
print(f"Device Tensor is on {my_tensor.device}")
print(f"Shape of tensor {my_tensor.shape}")
print(f"Requires gradient: {my_tensor.requires_grad}")

Information about tensor: tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
Type of Tensor {my_tensor.dtype}
Device Tensor is on cpu
Shape of tensor torch.Size([2, 3])
Requires gradient: True


<h3 style="color:#BF66F2 ">  Create tensors in various way </h3>

In [5]:
print("> using zeros")
print(torch.zeros((3, 3)))
print("> using rand")
print(torch.rand((3, 3)))
print("> using ones")
print(torch.ones((3, 3)))
print("> using eye")
print(torch.eye(5, 5))
print("> using arange")
print(torch.arange(start=0, end=5, step=1))
print("> using linspace")
print(torch.linspace(start=0.1, end=1, steps=10))
print("> using empty")
print(torch.empty(size=(3, 3)))
print("> using empty -normal")
print(torch.empty(size=(1, 5)).normal_(mean=0, std=1))
print("> using empty -uniform")
print(torch.empty(size=(1, 5)).uniform_(0, 1))
print("> using diag")
print(torch.diag(torch.ones(3)))

> using zeros
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
> using rand
tensor([[0.8887, 0.0859, 0.6226],
        [0.0683, 0.1083, 0.5266],
        [0.3486, 0.7490, 0.9730]])
> using ones
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
> using eye
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])
> using arange
tensor([0, 1, 2, 3, 4])
> using linspace
tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000, 0.9000,
        1.0000])
> using empty
tensor([[9.5423e-35, 4.5890e-41, 1.9274e-34],
        [0.0000e+00, 4.4842e-44, 0.0000e+00],
        [8.9683e-44, 0.0000e+00, 1.9235e-34]])
> using empty -normal
tensor([[-1.0812, -2.7225,  0.6885, -1.8093, -0.4509]])
> using empty -uniform
tensor([[0.9151, 0.7595, 0.2144, 0.3425, 0.4196]])
> using diag
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])


In [6]:
tensor = torch.arange(4)  
print(f"Starting tensor = {tensor}")
print()
print(f"Converted Boolean: {tensor.bool()}")  
print(f"Converted int16 {tensor.short()}") 
print(f"Converted int64 {tensor.long()}")  
print(f"Converted float16 {tensor.half()}")
print(f"Converted float32 {tensor.float()}")
print(f"Converted float64 {tensor.double()}") 


np_array = np.zeros((5, 5))
tensor = torch.from_numpy(np_array)
np_array_again = (tensor.numpy()) 

np_array_again

Starting tensor = tensor([0, 1, 2, 3])

Converted Boolean: tensor([False,  True,  True,  True])
Converted int16 tensor([0, 1, 2, 3], dtype=torch.int16)
Converted int64 tensor([0, 1, 2, 3])
Converted float16 tensor([0., 1., 2., 3.], dtype=torch.float16)
Converted float32 tensor([0., 1., 2., 3.])
Converted float64 tensor([0., 1., 2., 3.], dtype=torch.float64)


array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]])

<h2 style="color:#BF66F2 ">  => Tensor Math & Comparison: </h2>

In [7]:
x = torch.tensor([13, 62, 4])
y = torch.tensor([91, 39, 31])

x,y

(tensor([13, 62,  4]), tensor([91, 39, 31]))

In [8]:
""" All operations followed by _ will mutate the tensor in place. """

#### Addition
zadd1 = torch.empty(3)
torch.add(x, y, out=zadd1)  
zadd2 = torch.add(x, y)  
zadd3 = x + y  
# Subtraction
zsub = x - y 
# Division (element-wise division if of equal shape)
zdiv = torch.true_divide(x, y)  

# Inplace Operations
t = torch.zeros(3)

t.add_(x)  
t += x 

## Exponentiation 
zexp1 = x.pow(2)
zexp2 = x**2

print(f"zadd1 {zadd1}")
print(f"zadd2 {zadd1}")
print(f"zadd3 {zadd1}")
print(f"zsub {zsub}")
print(f"zdiv {zdiv}")
print(f"t {t}")
print(f"zexp1 {zexp1}")
print(f"zexp2 {zexp2}")

zadd1 tensor([104., 101.,  35.])
zadd2 tensor([104., 101.,  35.])
zadd3 tensor([104., 101.,  35.])
zsub tensor([-78,  23, -27])
zdiv tensor([0.1429, 1.5897, 0.1290])
t tensor([ 26., 124.,   8.])
zexp1 tensor([ 169, 3844,   16])
zexp2 tensor([ 169, 3844,   16])


In [9]:
### Simple Comparison
istru1 = x > 0  
istru2 = x < 0  
istru1, istru2

(tensor([True, True, True]), tensor([False, False, False]))

In [10]:
""" Matrix Multiplication """
x1 = torch.rand((2, 5))
x2 = torch.rand((5, 3))
x3 = torch.mm(x1, x2)
x3 = x1.mm(x2)  

x1,x2,x3

(tensor([[0.9988, 0.2362, 0.8945, 0.2892, 0.1626],
         [0.6635, 0.5766, 0.1184, 0.2094, 0.3850]]),
 tensor([[0.2388, 0.8636, 0.6486],
         [0.1661, 0.4776, 0.1657],
         [0.4075, 0.8635, 0.0388],
         [0.1257, 0.6901, 0.5808],
         [0.0402, 0.9301, 0.4934]]),
 tensor([[0.6852, 2.0986, 0.9699],
         [0.3443, 1.4531, 0.8421]]))

In [11]:
# Matrix Exponentiation
matrix_exp = torch.rand(5, 5)
print(matrix_exp.matrix_power(3))  

# Element-wise Multiplication
z = x * y  
print(f"Elementwise {z}")

# Dot product
dot = torch.dot(x, y)
print(f"\n Dot product {dot}")

tensor([[3.8089, 2.5563, 2.2565, 2.7977, 2.8177],
        [5.0540, 3.3086, 2.9690, 3.6719, 3.7327],
        [2.9356, 1.9235, 1.7310, 2.0870, 2.1043],
        [3.1672, 2.1783, 1.9301, 2.1526, 2.2150],
        [3.7691, 2.4321, 2.2943, 2.6194, 2.6318]])
Element wise tensor([1183, 2418,  124])

 Dot product 3725


In [12]:
""" Batch Matrix Multiplication => between the last two dimensions of each tensor. """
batch = 32
n = 10
m = 20
p = 30
tensor1 = torch.rand((batch, n, m))
tensor2 = torch.rand((batch, m, p))
out_bmm = torch.bmm(tensor1, tensor2)

print(tensor1.shape)
tensor1[:2]

torch.Size([32, 10, 20])


tensor([[[8.4261e-01, 3.0092e-01, 7.8176e-01, 3.1163e-01, 1.4427e-01,
          8.6063e-01, 8.2137e-02, 1.9298e-01, 2.5813e-01, 3.3769e-01,
          5.7559e-01, 4.0910e-01, 1.1051e-03, 3.7224e-01, 1.2842e-01,
          7.3466e-01, 3.4883e-01, 4.9214e-02, 1.4183e-01, 2.4460e-01],
         [9.9859e-01, 3.9346e-01, 5.7881e-01, 1.3190e-01, 2.1938e-01,
          4.5551e-01, 3.1521e-01, 9.6823e-01, 7.2722e-02, 1.9922e-01,
          1.5762e-01, 1.3863e-01, 5.7842e-02, 8.7356e-01, 1.5956e-01,
          4.9559e-01, 6.3387e-02, 2.3631e-01, 7.9616e-01, 6.0165e-01],
         [2.7927e-01, 3.8414e-01, 9.5781e-01, 5.7902e-02, 9.0150e-01,
          1.5924e-01, 7.2419e-01, 5.3276e-01, 9.0321e-01, 7.0979e-01,
          7.9326e-02, 8.1351e-01, 9.2070e-02, 4.8113e-01, 3.4554e-01,
          8.4606e-01, 6.1683e-01, 4.8480e-01, 8.4794e-02, 6.3852e-01],
         [6.0609e-01, 7.2384e-02, 8.4133e-01, 8.2729e-01, 8.5949e-01,
          7.1178e-01, 6.7081e-01, 1.1893e-01, 3.8066e-01, 1.2461e-01,
          9.7887e

In [13]:
print(out_bmm[:2])

tensor([[[3.4172, 3.8610, 3.4128, 4.0864, 4.3969, 3.2427, 2.4913, 5.0018,
          4.0682, 3.9401, 3.6040, 4.2782, 3.8998, 3.8991, 3.5839, 3.6537,
          4.2483, 2.9701, 3.1290, 4.5038, 3.4652, 3.3697, 2.9554, 3.7689,
          3.4784, 2.9475, 3.5563, 2.7185, 3.1320, 3.1868],
         [2.8471, 4.6272, 3.6242, 4.8018, 4.5335, 3.2338, 3.8616, 4.8141,
          3.8052, 4.2151, 4.9653, 5.0533, 4.2333, 4.2066, 3.7224, 4.5597,
          4.1895, 3.0118, 3.3971, 4.6309, 4.1493, 2.9031, 2.6158, 4.5307,
          4.1657, 3.4376, 4.7756, 3.0895, 3.2068, 3.5410],
         [4.1296, 6.0509, 5.8640, 5.4673, 5.2855, 5.6392, 3.7974, 6.7183,
          6.2356, 5.0796, 5.1835, 6.2738, 5.1677, 5.9051, 5.3783, 5.3138,
          5.4212, 4.4006, 4.5818, 5.8626, 3.9073, 4.2488, 4.6746, 6.3730,
          5.5211, 5.4732, 5.3882, 4.4101, 4.1686, 4.4311],
         [5.7404, 6.6648, 6.0872, 6.6425, 7.3778, 5.2485, 4.3733, 8.0382,
          5.8612, 6.1575, 5.8222, 6.4290, 6.2461, 6.6716, 5.8081, 5.8750,
         

In [14]:
""" Broadcasting """
x1 = torch.rand((5, 5))
x2 = torch.ones((1, 5))
bd = (x1 - x2)  
# Element-wise exponentiation
be = (x1**x2)  

bd.shape, be.shape 

(torch.Size([5, 5]), torch.Size([5, 5]))

<h2 style="color:#BF66F2 ">  => Tensor operations: </h2>

In [15]:
sum_x = torch.sum(x, dim=0)

tx = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
sum_tx = torch.sum(tx, dim=0)

print(x)
print(sum_x)
print(tx)
print(sum_tx)

tensor([13, 62,  4])
tensor(79)
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])
tensor([[ 6,  8],
        [10, 12]])


In [16]:
x.max(dim=0)
#or ...
#values, indices = torch.max(x, dim=0)

torch.return_types.max(
values=tensor(62),
indices=tensor(1))

In [17]:
x.min(dim=0) 
#or ...
#values, indices = torch.min(x, dim=0)  

torch.return_types.min(
values=tensor(4),
indices=tensor(2))

In [18]:
abs_x = torch.abs(x)

In [19]:
# Index of the max / min value
in1 = torch.argmax(x, dim=0)  
in2 = torch.argmin(x, dim=0) 

In [20]:
mean_x = torch.mean(x.float(), dim=0)  #x must be float

In [21]:
# Elementwise comparison
comp = torch.eq(x, y)

In [22]:
tt = torch.tensor([1, 0, 1, 1, 1], dtype=torch.bool)
tt

tensor([ True, False,  True,  True,  True])

In [23]:
sorted_y, indices = torch.sort(y, dim=0, descending=False)
sorted_y

tensor([31, 39, 91])

In [24]:
cla = torch.clamp(tt, min=0)
cla

tensor([1, 0, 1, 1, 1])

In [25]:
""" Any and All """
anyx = torch.any(x)  # any element is True or non-zero?
allx = torch.all(x)  # all elements are True or non-zero?
anyx, allx

(tensor(True), tensor(True))

<h2 style="color:#BF66F2 ">  => Tensor Indexing </h2>

In [26]:
batch_size = 10
features = 25
te = torch.rand((batch_size, features))
te

tensor([[0.5357, 0.2601, 0.3347, 0.5389, 0.9592, 0.0840, 0.5363, 0.5458, 0.7332,
         0.5033, 0.4220, 0.7546, 0.8533, 0.8763, 0.2569, 0.8281, 0.9374, 0.5556,
         0.6928, 0.6930, 0.7050, 0.3199, 0.1857, 0.3942, 0.0534],
        [0.0603, 0.8431, 0.8331, 0.7032, 0.1361, 0.1627, 0.0261, 0.3157, 0.4291,
         0.6073, 0.2645, 0.4724, 0.8195, 0.9654, 0.8667, 0.0721, 0.8472, 0.9209,
         0.7602, 0.8365, 0.5832, 0.5591, 0.7055, 0.7231, 0.6286],
        [0.1058, 0.6518, 0.6896, 0.3820, 0.8234, 0.9171, 0.4418, 0.5996, 0.6889,
         0.0093, 0.3729, 0.5428, 0.0521, 0.3721, 0.4343, 0.7559, 0.4039, 0.6153,
         0.3644, 0.6034, 0.7974, 0.6528, 0.6158, 0.6581, 0.6424],
        [0.5645, 0.1624, 0.6306, 0.5633, 0.9614, 0.1309, 0.0600, 0.8707, 0.1207,
         0.3257, 0.6645, 0.6538, 0.6423, 0.7665, 0.1972, 0.1321, 0.6533, 0.1512,
         0.4793, 0.3954, 0.4870, 0.9543, 0.8805, 0.5194, 0.9795],
        [0.2722, 0.3399, 0.0431, 0.1068, 0.3334, 0.8692, 0.5377, 0.4068, 0.6149,
       

In [27]:
print(te[0].shape)
print(te[:, 0].shape)
print(te[2, 0:10].shape)

torch.Size([25])
torch.Size([10])
torch.Size([10])


In [28]:
ttt = torch.arange(10)
indices = [2, 5, 8]
print(ttt[indices])

tensor([2, 5, 8])


In [29]:
x = torch.rand((3, 5))
rows = torch.tensor([1, 0])
cols = torch.tensor([4, 0])
x, rows, cols

(tensor([[0.1364, 0.5146, 0.8312, 0.9020, 0.8082],
         [0.5645, 0.9798, 0.1588, 0.0601, 0.6598],
         [0.9875, 0.0044, 0.0702, 0.1203, 0.3766]]),
 tensor([1, 0]),
 tensor([4, 0]))

In [30]:
print(x[rows, cols])    # second row/fifth column elem + first row/first column elem => row 1 and 0 + column 4 0 

tensor([0.6598, 0.1364])


In [31]:
x = torch.arange(10)
print(x)
print(x[(x < 2) | (x > 8)]) 
print(x[x.remainder(2) == 0])

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([0, 1, 9])
tensor([0, 2, 4, 6, 8])


In [32]:
torch.where(x > 5, x, x * 2)

tensor([ 0,  2,  4,  6,  8, 10,  6,  7,  8,  9])

In [33]:
tor = torch.tensor([0, 0, 1, 2, 2, 3, 4]).unique()
print("num of dimensions of tor is ", tor.ndimension())
print("num of element of tor is ", tor.numel()) 
tor

num of dimensions of tor is  1
num of element of tor is  5


tensor([0, 1, 2, 3, 4])

<h2 style="color:#BF66F2 ">  => Tensor Changing shape </h2>

In [34]:
""" Reshape: => View and reshape do the same action ...
But the latter act on contiguous tensors => meaning that if the tensor it is not stored contiguously in memory 
it makes a copy of the tensor to make it contiguously stored => leading to performance loss in general in general !
"""
x = torch.arange(9)
x_3x3_v = x.view(3, 3)
x_3x3_r = x.reshape(3, 3)
x_3x3_v, x_3x3_r

(tensor([[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]]),
 tensor([[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]]))

<h3 style="color:#BF66F2"> Common errors: </h3>
<div style="margin-top: -17px;">
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). <br>
.reshape(...) should be used instead. <br>

PyTorch is unable to create a new tensor with the desired shape because the memory layout of y is not contiguous. <br>
Now the tensor is stored in memory as it was stored [0, 1, 2, ... 8], whereas now it's [0, 3, 6, 1, 4, 7, 2, 5, 8] <br>
</div>

<h3 style="color:#BF66F2"> Recap: </h3>
<div style="margin-top: -22px;">
A tensor is said to be contiguous if its data is stored in a contiguous block of memory,  <br>
which means that the elements of the tensor are stored in a linear sequence without any gaps or padding. <br>
</div>

In [35]:
""" Check continuity. 
After y.t(), PyTorch creates a new tensor y that is a transposed version of x_3x3_v. 
However, the transpose operation changes the memory layout of the tensor, which means that the elements 
of y are no longer stored in a contiguous block of memory.
"""
y = x_3x3_v.t()
print("is contiguous?")
print(y.is_contiguous())
#y.view(9) Error! But calling .contiguous() before view and it works
print(y.contiguous().view(9)) 

is contiguous?
False
tensor([0, 3, 6, 1, 4, 7, 2, 5, 8])


In [36]:
x1 = torch.rand(2, 5)
x2 = torch.rand(2, 5)

cat1 = torch.cat((x1, x2), dim=0)
cat2 = torch.cat((x1, x2), dim=1)

print(x1)
print(x2)
print()
print(cat1)
print()
print(cat2)

tensor([[0.2886, 0.2156, 0.0916, 0.1131, 0.9295],
        [0.0850, 0.1922, 0.7210, 0.6702, 0.7830]])
tensor([[0.4782, 0.6378, 0.9850, 0.7037, 0.0858],
        [0.6427, 0.9730, 0.6960, 0.3944, 0.1554]])

tensor([[0.2886, 0.2156, 0.0916, 0.1131, 0.9295],
        [0.0850, 0.1922, 0.7210, 0.6702, 0.7830],
        [0.4782, 0.6378, 0.9850, 0.7037, 0.0858],
        [0.6427, 0.9730, 0.6960, 0.3944, 0.1554]])

tensor([[0.2886, 0.2156, 0.0916, 0.1131, 0.9295, 0.4782, 0.6378, 0.9850, 0.7037,
         0.0858],
        [0.0850, 0.1922, 0.7210, 0.6702, 0.7830, 0.6427, 0.9730, 0.6960, 0.3944,
         0.1554]])


In [37]:
# Unroll x1 into one long vector with 10 elements
unro1 = x1.view(-1)
unro1

tensor([0.2886, 0.2156, 0.0916, 0.1131, 0.9295, 0.0850, 0.1922, 0.7210, 0.6702,
        0.7830])

In [43]:
batch = 64
x = torch.rand((batch, 2, 5))
unro2 = x.view(batch, -1)

x[:2], unro2[:2]

(tensor([[[0.2106, 0.9901, 0.5215, 0.7674, 0.7645],
          [0.1227, 0.5335, 0.0043, 0.4914, 0.0064]],
 
         [[0.1640, 0.6757, 0.2341, 0.1470, 0.9348],
          [0.7159, 0.4676, 0.4251, 0.5388, 0.2570]]]),
 tensor([[0.2106, 0.9901, 0.5215, 0.7674, 0.7645, 0.1227, 0.5335, 0.0043, 0.4914,
          0.0064],
         [0.1640, 0.6757, 0.2341, 0.1470, 0.9348, 0.7159, 0.4676, 0.4251, 0.5388,
          0.2570]]))

In [44]:
# Permute
perm = x.permute(0, 2, 1)
perm[:5]

tensor([[[0.2106, 0.1227],
         [0.9901, 0.5335],
         [0.5215, 0.0043],
         [0.7674, 0.4914],
         [0.7645, 0.0064]],

        [[0.1640, 0.7159],
         [0.6757, 0.4676],
         [0.2341, 0.4251],
         [0.1470, 0.5388],
         [0.9348, 0.2570]],

        [[0.7083, 0.6712],
         [0.2961, 0.4143],
         [0.3482, 0.4635],
         [0.7602, 0.1377],
         [0.6553, 0.8822]],

        [[0.5195, 0.9047],
         [0.9322, 0.2550],
         [0.8232, 0.9111],
         [0.6901, 0.3973],
         [0.3389, 0.2164]],

        [[0.7395, 0.3150],
         [0.0132, 0.8828],
         [0.5130, 0.3999],
         [0.0814, 0.1520],
         [0.3744, 0.1290]]])

In [55]:
""" Split x last dimension into chunks of 2 the last dimension """ 
chu = torch.chunk(x, chunks=2, dim=1)
print(chu[0].shape)
print(chu[1].shape)

chu[1:2:3]

torch.Size([64, 1, 5])
torch.Size([64, 1, 5])


(tensor([[[0.1227, 0.5335, 0.0043, 0.4914, 0.0064]],
 
         [[0.7159, 0.4676, 0.4251, 0.5388, 0.2570]],
 
         [[0.6712, 0.4143, 0.4635, 0.1377, 0.8822]],
 
         [[0.9047, 0.2550, 0.9111, 0.3973, 0.2164]],
 
         [[0.3150, 0.8828, 0.3999, 0.1520, 0.1290]],
 
         [[0.7984, 0.0282, 0.2919, 0.1347, 0.6902]],
 
         [[0.4025, 0.3768, 0.5137, 0.3990, 0.3248]],
 
         [[0.9057, 0.0207, 0.1701, 0.5858, 0.1975]],
 
         [[0.4497, 0.0796, 0.7945, 0.2010, 0.3226]],
 
         [[0.3809, 0.0394, 0.8266, 0.1538, 0.5597]],
 
         [[0.4836, 0.8849, 0.2649, 0.9121, 0.0803]],
 
         [[0.3814, 0.3743, 0.3280, 0.2800, 0.1398]],
 
         [[0.6580, 0.0643, 0.6374, 0.4317, 0.0315]],
 
         [[0.2898, 0.1485, 0.9488, 0.1526, 0.0411]],
 
         [[0.4615, 0.2853, 0.0089, 0.9115, 0.6696]],
 
         [[0.4493, 0.2971, 0.9602, 0.8934, 0.2660]],
 
         [[0.6981, 0.3945, 0.1486, 0.4506, 0.5218]],
 
         [[0.2006, 0.7963, 0.6755, 0.4820, 0.8380]],
 
         [

<h3 style="color:#BF66F2"> Recap: Unsqueeze</h3>
<div style="margin-top: -22px;">
Adding a new dimension to a tensor (inverse of squeeze)! 

- It adds a dimension of size 1 at the position given by the dim argument. <br>
- The number of dimensions of the tensor increases by 1 after applying unsqueeze. <br>
- The new dimension is inserted at the location specified by dim. All subsequent dimensions are shifted. <br>
</div>

In [41]:
""" Unsqueeze """ 
x = torch.arange(10)
print(x.unsqueeze(0).shape)  
print(x.unsqueeze(1).shape)  
print()

x, x.unsqueeze(0), x.unsqueeze(1)

torch.Size([1, 10])
torch.Size([10, 1])



(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
 tensor([[0],
         [1],
         [2],
         [3],
         [4],
         [5],
         [6],
         [7],
         [8],
         [9]]))

In [42]:
torx = torch.arange(10).unsqueeze(0).unsqueeze(1)
sqsq = torx.squeeze(1)
torx, sqsq

(tensor([[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]),
 tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]))