# DB[RC/S3]
(Density-Based Residue Clustering by Dissimilarity Between Sequence SubSets)
#
## Ring-hydroxylating Dioxygenases

In [None]:
from protlearn import *
from sklearn.ensemble import RandomForestRegressor

In [None]:
###====================================================================================================
### Parameters
###====================================================================================================

class Args(object):
    """Mutable attribute container for the run parameters.

    Instances start empty; parameters are attached afterwards as plain
    attributes (``args.file = ...``).
    """

    def __init__(self) -> None:
        # Nothing to initialise. The former ``self.__getattr__ = None`` had no
        # effect on attribute lookup (special methods are resolved on the
        # type, not the instance) and only added a stray entry to vars(self).
        pass

    def __repr__(self) -> str:
        """Return the stored parameters as ``Args(name=value, ...)``."""
        fields = ', '.join('%s=%r' % item for item in sorted(vars(self).items()))
        return 'Args(%s)' % fields
import os

args = Args()
# Alignment file handed to msa.parse().
args.file = 'alignment.fasta'
# Not read anywhere in the visible part of this notebook.
args.expand_alphabet = False
# Threshold used both for the per-column gap filter and for residue frequency.
args.min_freq = .25
# Second and third arguments of the OPTICS call; min_size is also the minimum
# number of members a refined cluster must keep.
args.max_dist = 1.
args.min_size = 3
# Prefix of every output file; derived from args.file when left as None.
args.out = None

if args.out is None:
    # Strip only the final extension of the file *name*. Splitting the whole
    # path at the first '.' gave an empty stem for inputs such as
    # './alignment.fasta' or '../data/alignment.fasta'.
    stem = os.path.splitext(os.path.basename(args.file))[0]
    args.out = 'output/' + stem

# Later cells write '<args.out>_*.png/csv/nwk'; make sure the folder exists.
out_dir = os.path.dirname(args.out)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

In [None]:
# #====================================================================================================
# Build the alignment object from the file named in args.file.
# NOTE(review): MSA comes from the protlearn star-import; presumably parse()
# registers the file and read() loads headers/sequences — confirm there.
msa = MSA()
msa.parse(args.file)
msa.read()

In [None]:
# Display the headers held by the alignment object.
msa.headers

In [None]:
# Display the dimensions of the sequence array.
msa.sequences.shape

In [None]:
# Give every sequence the same weight, so the weights add up to one.
n_seqs = msa.size
msa.weights = [1. / float(n_seqs) for _ in range(n_seqs)]
# msa.henikoff()
msa.weights

In [None]:
# Tabular view of the alignment; rows are weighted per sequence below.
df_msa = pd.DataFrame(msa.sequences)

# Weighted fraction of gap characters ('-') in each column.
is_gap = df_msa.eq('-')
gap_ratio = is_gap.mul(msa.weights, axis=0).sum()

# Keep the columns whose weighted gap fraction is below args.min_freq.
selected_columns = gap_ratio[gap_ratio < args.min_freq].index

# Restrict the alignment to the retained columns.
df_raw = df_msa.loc[:, selected_columns]
df_raw

In [None]:
"""
Attention!!! This snippet is specific for this particular example so it must be adapted to
each particular case.
"""
# Load the tab-separated annotation table and discard rows lacking an EC number.
df_metadata = pd.read_csv('data.tsv', sep='\t')
df_metadata = df_metadata.dropna(subset=['EC number'])

# A single entry may list several EC numbers separated by '; '.
df_metadata['EC number'] = df_metadata['EC number'].map(lambda ec: ec.split('; '))
df_metadata

In [None]:
# Entry identifier = header text before the first '/', then before the first '_'.
seq_ids = [header.split('/')[0].split('_')[0] for header in msa.headers]
# Row number of each sequence in the alignment.
seq_idx = list(range(msa.size))
df_indices = pd.DataFrame({'Entry': seq_ids, 'Index': seq_idx})
df_indices

In [None]:
# Attach each metadata row to its alignment row number via the shared 'Entry' column.
df_indexed = df_metadata.merge(df_indices, on='Entry')
# Bring in the residue columns by matching 'Index' against df_raw's row index.
df_merged = df_indexed.merge(df_raw, left_on='Index', right_index=True, how='inner')
df_merged

In [None]:
# Turn each list of EC numbers into one row per EC number; rows that came from
# the same entry keep the same index label.
df_label = df_merged.explode('EC number')
df_label

In [None]:
# Number of rows per EC number.
df_label['EC number'].value_counts()

In [None]:
# Label = EC number without its last dot-separated field.
df_label['Label'] = df_label['EC number'].map(lambda ec: ec.rpartition('.')[0])
df_label

In [None]:
# Number of rows per label.
df_label['Label'].value_counts()

In [None]:
# One indicator column (0/1) per distinct label.
df_encoded = pd.get_dummies(df_label['Label']).astype(int)
df_encoded

In [None]:
# Residue columns only: the columns of df_raw as they appear in the labelled table.
df_chars = df_label[df_raw.columns]
df_chars

In [None]:
def target_mean(df, by, on):
    """Replace each value of column ``by`` with the mean of ``on`` over its group.

    Args:
        df: DataFrame containing both columns.
        by: Label of the column whose values define the groups.
        on: Label of the numeric column to average within each group.

    Returns:
        A Series aligned with ``df`` (same index, named ``by``) in which every
        entry is the mean of ``on`` over the rows sharing that ``by`` value.
    """
    group_means = df[on].groupby(df[by]).mean()
    return df[by].map(group_means)

In [None]:
# For every label, fit a random forest that predicts the 0/1 indicator from the
# target-encoded residue columns and collect its feature importances.
ftr_imp = []
for label in df_encoded.columns:
    df_num = df_chars.copy()
    df_num['Target'] = df_encoded[label].tolist()

    # Encode the residue columns only. Iterating df_num.columns would also
    # visit 'Target' and encode the target by itself.
    for col in df_chars.columns:
        df_num[col] = target_mean(df_num, by=col, on='Target')

    # Split the dataset into features (X) and target (y)
    X = df_num.drop('Target', axis=1)
    y = df_num['Target']

    # Fit a random forest model to the data (fixed seed for repeatable output)
    rf = RandomForestRegressor(random_state=0)
    rf.fit(X, y)
    ftr_imp.append(rf.feature_importances_)

# One row per label, one column per feature; average over the labels.
arr = np.array(ftr_imp)
averages = np.mean(arr, axis=0)

averages

In [None]:
# Pair each averaged importance with its feature label and rank them from most
# to least important. X is whatever the last iteration of the loop above left
# behind; only its column labels are used here.
importances = pd.Series(averages, index=X.columns).sort_values(ascending=False)
importances

In [None]:
# Running total of the importances, taken in descending order of importance.
cumulative_importance = importances.cumsum()
cumulative_importance

In [None]:
# Keep the leading features whose cumulative importance stays within 0.75 (75%).
within_cutoff = cumulative_importance <= 0.75
most_important = importances[within_cutoff].sort_values(ascending=False)
selected_features = importances[within_cutoff].index
higher_importance = cumulative_importance[selected_features]

fig, ax1 = plt.subplots(figsize=(16, 4))
xvalues = range(len(most_important))

# Left axis: individual importance of each retained feature, as bars.
ax1.bar(xvalues, most_important, color='b')
ax1.set_ylabel('Percentage of total importance')
ax1.tick_params(axis='y')

# Right axis: cumulative importance, as a line.
ax2 = ax1.twinx()
ax2.plot(xvalues, higher_importance, color='r', marker='.')
ax2.set_ylabel('Cumulative importance')
ax2.tick_params(axis='y')

# Label every tick with its feature, then turn the labels vertical on both axes.
plt.xticks(xvalues, most_important.index)
for axis in (ax1, ax2):
    plt.setp(axis.xaxis.get_majorticklabels(), rotation=90)

# Adjust layout to make sure labels are visible
# plt.tight_layout()

plt.show()

In [None]:
# Select only the columns corresponding to the selected features.
# NOTE(review): df_num is the frame left over from the last iteration of the
# importance loop, i.e. encoded against the last label only — confirm intended.
df_selected = df_num[selected_features]
df_selected

In [None]:
# _, axs = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
#
# # Plot the original DataFrame
# sns.heatmap(X, cmap='coolwarm', xticklabels=False, yticklabels=False, ax=axs[0])
# axs[0].set_title('Original DataFrame')
#
# # Plot the sorted DataFrame
# sns.heatmap(X[importances.index], cmap='coolwarm', xticklabels=False, yticklabels=False, ax=axs[1])
# axs[1].set_title('Sorted DataFrame by\ndescending feature importance')
#
# # Plot the filtered DataFrame
# sns.heatmap(df_selected, cmap='coolwarm', yticklabels=False, ax=axs[2])
# axs[2].set_title('Filtered DataFrame by\ncumulative feature importance')
#
# plt.tight_layout()
# plt.show()

In [None]:
# #====================================================================================================
# Gather the entries stored in msa.collection for every retained column.
R = [residue for col in most_important.index for residue in msa.collection[col]]

In [None]:
# #====================================================================================================
G = nx.Graph()

# Evaluate the frequency filter once per residue rather than once per pair.
frequent = [r for r in R if r.p() >= args.min_freq]

# Connect every pair of retained residues. The edge weight is the summed weight
# of the sequences carrying exactly one of the two residues divided by the
# summed weight of the sequences carrying at least one of them (a weighted
# Jaccard distance between the two sequence sets).
for i, a in enumerate(frequent):
    for b in frequent[i + 1:]:
        differing = sum(msa.weights[s] for s in a.sequence_indices ^ b.sequence_indices)
        covered = sum(msa.weights[s] for s in a.sequence_indices | b.sequence_indices)
        G.add_edge(a, b, weight=float(differing) / float(covered))

# Nodes ordered from most to least frequent; this order defines the rows and
# columns of the distance matrix.
N = sorted(G.nodes(), key=lambda x: x.p(), reverse=True)
for n in N:
    print(n)

# Dense matrix of edge weights following the order of N.
D = nx.to_numpy_array(G, nodelist=N)
D

In [None]:
# Plot the distance matrix D as an image (rows/columns follow the order of N)
fig, ax = plt.subplots()
im = ax.imshow(D, cmap='viridis')

# Add a colorbar mapping the colours back to matrix values
cbar = ax.figure.colorbar(im, ax=ax)

# Show the plot
plt.show()

In [None]:
# #====================================================================================================
# Cluster the residues with OPTICS using D as a precomputed distance matrix.
# NOTE(review): assuming `optics` is pyclustering's class (the methods used
# here match its API). Its fifth positional parameter is `ccore`, and
# `data_type` is read from keyword arguments only. Passing 'distance_matrix'
# positionally therefore left data_type at its default 'points', so it must be
# given by keyword.
optics_instance = optics(D, args.max_dist, args.min_size, None, data_type='distance_matrix')
optics_instance.process()
clusters = optics_instance.get_clusters()

# Reachability plot: one bar per point, in cluster ordering.
ordering = ordering_analyser(optics_instance.get_ordering())
ordering = ordering.cluster_ordering
plt.figure()
plt.bar(range(len(ordering)), ordering, width=1., color='black')
plt.xlim([0, len(ordering)])
plt.xlabel('Points')
plt.ylabel('Reachability Distance')
plt.savefig('%s_reachability_plot.png' % args.out)

In [None]:
# #====================================================================================================
# Rank the clusters by the mean frequency of their members (highest first).
clusters = sorted(clusters, key=lambda members: np.mean([N[m].p() for m in members]), reverse=True)

# Within each cluster keep a single residue per alignment position, then drop
# clusters that end up with fewer than args.min_size members.
refined = []
for members in clusters:
    # Group the cluster members by alignment position.
    by_position = {}
    for m in members:
        by_position.setdefault(N[m].position, []).append(m)

    # Subset made of every sequence carrying at least one member of the cluster.
    c = Subset(msa, list(set.union(*(set(N[m].sequence_indices) for m in members))))

    # Winner per position, computed once instead of once per member. max()
    # returns the first maximal element, so ties resolve as before.
    best = {
        pos: max(group, key=lambda m: N[m].p.given(c))
        for pos, group in by_position.items()
    }
    kept = [m for m in members if m == best[N[m].position]]

    if len(kept) >= args.min_size:
        refined.append(kept)
clusters = refined
clusters

In [None]:
# #====================================================================================================
# Write one CSV table per cluster, each preceded by a 'Cluster <n>' line and
# followed by a blank line.
with open('%s_clusters.csv' % args.out, 'w') as outfile:
    for number, members in enumerate(clusters, start=1):
        outfile.write('Cluster %d\n' % number)
        # Rows are listed by alignment position; columns are reported 1-based.
        ordered = sorted(members, key=lambda idx: N[idx].position)
        table = pd.DataFrame({
            'MSA\nColumn': [N[idx].position + 1 for idx in ordered],
            'Feature': [N[idx] for idx in ordered],
            'Frequency': ['%.2f' % round(N[idx].p(), 2) for idx in ordered],
        })
        outfile.write(table.to_csv(index=False))
        outfile.write('\n')

In [None]:
# #====================================================================================================
# H[s][c] = fraction of the members of cluster c that sequence s carries.
H = np.array([
    [
        float(sum(1 for k in members if seq in N[k].sequence_indices)) / float(len(members))
        for members in clusters
    ]
    for seq in range(msa.size)
])
H

In [None]:
# #====================================================================================================
# Hierarchical clustering of the rows of H (one row per sequence), using the
# 'average' linkage method.
Z = linkage(H, 'average')
# The dendrogram is drawn on the figure created here, so the order matters.
fig = plt.figure(figsize=(25, 10))
dn = dendrogram(Z, labels=np.array(msa.headers))
plt.savefig('%s_dendrogram.png' % args.out)
# Convert the linkage matrix into a tree object and export it to a .nwk file.
# NOTE(review): get_newick comes from the star-import; presumably it returns
# Newick text labelled with the MSA headers — confirm.
tree = to_tree(Z, False)
with open('%s_dendrogram.nwk' % args.out, 'w') as outfile:
    outfile.write(get_newick(tree, "", tree.dist, msa.headers))

In [None]:
# #====================================================================================================
# Tabulate H, set the sequence identifiers aside and draw a row-clustered heatmap.
df = get_df(H, msa, range(msa.size), range(len(clusters)))
seq = df.pop('Seq. ID')
try:
    g = sns.clustermap(df, col_cluster=False, yticklabels=False, figsize=(4,4))
except SystemExit as exc:
    # Raising a bare string is itself a TypeError in Python 3; wrap the
    # message in a real exception and keep the original cause attached.
    raise RuntimeError('Warning: few clusters to draw a heatmap!') from exc

# Reorder the rows of H as in the heatmap's row dendrogram.
row_idx = g.dendrogram_row.reordered_ind
# col_idx = g.dendrogram_col.reordered_ind
col_idx = range(len(clusters))  # Keep column index without dendrogram
H = np.array([H[i] for i in row_idx])

# Export the reordered table, then save the current figure (the heatmap).
df = get_df(H, msa, row_idx, col_idx)
df.to_csv('%s_seq_adhesion.csv' % args.out)
plt.savefig('%s_seq_adhesion.png' % args.out)

In [None]:
# #====================================================================================================
# # Optional viewing
# #====================================================================================================
# Embed the residues in 2-D from the precomputed distance matrix, then rotate
# the embedding onto its principal axes.
# NOTE(review): no random_state is passed to MDS, so the layout can differ
# between runs.
mds = manifold.MDS(n_components=2, dissimilarity="precomputed", normalized_stress='auto')
pts = mds.fit(D).embedding_
clf = PCA(n_components=2)
pts = clf.fit_transform(pts)

# Residue frequencies in percent, computed once and shared by both panels.
freq_pct = np.array([n.p() for n in N]) * 100

# Boolean mask of the points OPTICS did not flag as noise. A set gives
# constant-time membership tests, and masking (instead of zip(*points)) keeps
# the cell from crashing when every point is noise.
noise = optics_instance.get_noise()
noise_lookup = set(noise)
keep = np.array([i not in noise_lookup for i in range(len(pts))], dtype=bool)

panels = (
    ('Residue Plot with noise', pts, freq_pct),
    ('Residue Plot without noise', pts[keep], freq_pct[keep]),
)

_, axs = plt.subplots(1, 2, figsize=(12, 4))
for ax, (title, xy, values) in zip(axs, panels):
    sc = ax.scatter(xy[:, 0], xy[:, 1], c=values, cmap='rainbow', vmin=0., vmax=100., alpha=.5)
    cb = plt.colorbar(sc, ax=ax)
    cb.set_label('Frequency (%s)' % '%')
    ax.set_title(title)

plt.savefig('%s_residue_plot_combined.png' % args.out)


In [None]:
# #====================================================================================================
# # END
# #====================================================================================================