# Instances where we use the linear (affine) transformation y = x · A^T + b, with weight matrix A and bias vector b

# Gene expression analysis: Transforming gene expression data

In [None]:
import torch
import numpy as np

# Simulated gene expression data: one row per sample, one column per gene
num_samples = 100
num_genes = 1000
gene_expression = torch.randn(num_samples, num_genes)

# Linear transformation y = x · A^T + b that maps the num_genes inputs of each
# sample to output_dim derived features (random weights, for illustration only)
output_dim = 50
A = torch.randn(output_dim, num_genes)
b = torch.randn(output_dim)

# Transform gene expression data; b is broadcast across the sample rows
transformed_expression = torch.matmul(gene_expression, A.t()) + b

print("Original shape:", gene_expression.shape)
print("Transformed shape:", transformed_expression.shape)

# Example analysis: rank the transformed features by their mean over samples.
# Each column of transformed_expression is a linear combination of all genes,
# so the indices below range over output_dim features, not over genes.
avg_expression = torch.mean(transformed_expression, dim=0)
top_k = min(10, output_dim)  # topk raises if k exceeds the number of features
top_features = torch.topk(avg_expression, k=top_k).indices
top_genes = top_features  # previous name, kept as an alias; these are not gene indices
print(f"Top {top_k} transformed feature indices (by mean value):", top_features.tolist())

# Protein Structure Prediction:

In [None]:
import torch

# Synthetic per-residue descriptors standing in for amino acid properties,
# laid out as (batch, residue, feature) with a single protein in the batch
num_residues, feature_dim = 200, 20
protein_features = torch.randn(1, num_residues, feature_dim)

# Weights and bias of the map y = x · A^T + b; every residue is sent to
# output_dim values, interpreted here as x/y/z coordinates
output_dim = 3
A = torch.randn(output_dim, feature_dim)
b = torch.randn(output_dim)

# Apply the map to all residues at once; b broadcasts over batch and residues
predicted_structure = protein_features @ A.T + b

print("Protein features shape:", protein_features.shape)
print("Predicted structure shape:", predicted_structure.shape)

# Show the coordinates of the leading residues of the only protein in the batch
print("First 5 predicted 3D coordinates:")
print(predicted_structure[0, :5].numpy())

# Sequence Analysis (DNA encoding):

In [None]:
import torch

# Toy DNA sequence to encode
dna_seq = "ATCGATCGATCG"

# One-hot encoding: row i holds a single 1 in the column assigned to dna_seq[i]
nucleotide_to_index = dict(zip("ATCG", range(4)))
column_of_position = torch.tensor([nucleotide_to_index[base] for base in dna_seq])
one_hot = torch.zeros(len(dna_seq), 4)
one_hot[torch.arange(len(dna_seq)), column_of_position] = 1

# Weights and bias of the map y = x · A^T + b from 4 nucleotide channels
# to output_dim features per position
output_dim = 8
A = torch.randn(output_dim, 4)
b = torch.randn(output_dim)

# Map every encoded position; b broadcasts across positions
transformed_seq = one_hot @ A.T + b

print("One-hot encoded shape:", one_hot.shape)
print("Transformed sequence shape:", transformed_seq.shape)

# Example analysis: take each position's largest feature value, then pick
# the position where that per-position maximum is greatest
max_pos = transformed_seq.max(dim=1).values.argmax()
print("Position with highest transformed value:", max_pos.item())

# Dimensionality Reduction:

Explanation of what happens during dimensionality reduction:

Original Data:

We start with high-dimensional data (100 dimensions in this case).
Each data point is represented by 100 features, which is difficult to visualize directly.
The left panel of each figure shows only the first 2 dimensions of this 100-dimensional space.


Dimensionality Reduction:

We reduce the 100-dimensional data to 2 dimensions.
This process attempts to preserve the most important information or patterns in the data.


Our Linear Transformation Method:

We use the equation y = x · A^T + b to transform the data.
This method projects the high-dimensional data onto a 2D plane.
The resulting plot shows how the data points are distributed in this new 2D space.
However, this random linear transformation may not optimally preserve the data structure.


PCA (Principal Component Analysis) Method:

PCA finds the directions (principal components) of maximum variance in the data.
It then projects the data onto these principal components.
The resulting plot shows the data distributed along the two most significant principal components.
PCA is often more effective at preserving the overall structure of the data.

The script calls plt.ion() at the beginning to enable interactive mode. Figures are displayed with plt.show() rather than written to disk with plt.savefig(). At the end, plt.ioff() followed by a final plt.show() keeps the plot windows open.

When you run this script in PyCharm:

The plots should appear in the "SciView" tool window (usually on the right side of the PyCharm window).
If the plots don't appear automatically, you might need to click on the "Python Scientific" tab in the tool window.
You can interact with the plots, zoom in/out, and pan around.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def _plot_original_vs_reduced(original, reduced, reduced_title,
                              reduced_xlabel, reduced_ylabel):
    """Draw the original data next to a 2D reduction of it and show the figure.

    Args:
        original: 2D array of the original samples; only columns 0 and 1
            are plotted.
        reduced: 2D array of reduced samples; columns 0 and 1 are plotted.
        reduced_title: Title of the right-hand panel.
        reduced_xlabel: x-axis label of the right-hand panel.
        reduced_ylabel: y-axis label of the right-hand panel.

    Returns:
        A tuple ``(figure, (left_axes, right_axes))``.
    """
    fig, (original_ax, reduced_ax) = plt.subplots(1, 2, figsize=(15, 6))

    # Left panel: original data (first 2 dimensions)
    original_ax.scatter(original[:, 0], original[:, 1], alpha=0.5)
    original_ax.set_title("Original Data (First 2 Dimensions)")
    original_ax.set_xlabel("Dimension 1")
    original_ax.set_ylabel("Dimension 2")

    # Right panel: the reduced representation
    reduced_ax.scatter(reduced[:, 0], reduced[:, 1], alpha=0.5)
    reduced_ax.set_title(reduced_title)
    reduced_ax.set_xlabel(reduced_xlabel)
    reduced_ax.set_ylabel(reduced_ylabel)

    plt.tight_layout()
    plt.show()
    return fig, (original_ax, reduced_ax)


# Enable interactive mode for matplotlib
plt.ion()

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Simulated high-dimensional biological data
num_samples = 1000
original_dim = 100
data = torch.randn(num_samples, original_dim)

# Linear transformation for dimensionality reduction (random weights and bias)
reduced_dim = 2
A = torch.randn(reduced_dim, original_dim)
b = torch.randn(reduced_dim)

# Reduce dimensionality using our linear transformation y = x · A^T + b
reduced_data = torch.matmul(data, A.t()) + b

print("Original data shape:", data.shape)
print("Reduced data shape:", reduced_data.shape)

# Use PCA for comparison; scikit-learn and matplotlib work on NumPy arrays,
# so convert the tensor once and reuse the array below
data_np = data.numpy()
pca = PCA(n_components=2)
pca_reduced_data = pca.fit_transform(data_np)

# Plot 1: Original vs Our Method
fig1, (ax1, ax2) = _plot_original_vs_reduced(
    data_np,
    reduced_data.numpy(),
    "Reduced Data (Our Method)",
    "Dimension 1",
    "Dimension 2",
)

# Plot 2: Original vs PCA
fig2, (ax1, ax2) = _plot_original_vs_reduced(
    data_np,
    pca_reduced_data,
    "Reduced Data (PCA)",
    "Principal Component 1",
    "Principal Component 2",
)

print("Explained variance ratio (PCA):", pca.explained_variance_ratio_)

# Keep the plot windows open
plt.ioff()
plt.show()