Skip to content

Commit

Permalink
Add WIP hp model and hp model renderer with plotly
Browse files Browse the repository at this point in the history
  • Loading branch information
d53dave committed Nov 22, 2019
1 parent fa23fa8 commit bf56ad8
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 26 deletions.
75 changes: 49 additions & 26 deletions examples/hp/hp_opt.py
Expand Up @@ -4,15 +4,28 @@

from csaopt.utils import clamp

Monomer = Tuple[int, int, int, int]
Chain2d = List[Monomer]

# -- Globals

hp_str = 'PHHPPHPHPPHPHPHPPHPPHHHHH'
eps = -1
h_idxs = [idx for idx, mm in enumerate(hp_str) if mm == 'H']

# -- Globals

Chain2d = List[Tuple[int, int, int, int]]
# @numba.cuda.jit(inline=True, device=True)
def is_valid_conformation(chain: Chain2d) -> bool:
row_max = len(chain)
for i in range(row_max):
for j in range(i + 1, row_max):
d = (chain[i][1] - chain[j][1])**2 + (chain[i][2] - chain[j][2])**2
if d < 1.0:
return False
return True


# -- Globals


def empty_state() -> Collection:
Expand All @@ -31,7 +44,6 @@ def acceptance_func(e_old: float, e_new: float, temp: float, rnd: float) -> floa

def initialize(state: MutableSequence, randoms: Sequence[float]) -> None:
generate_next(state, state, randoms, 0) # just delegate to generate_next
return


def evaluate(state: Sequence) -> float:
Expand All @@ -57,33 +69,44 @@ def rigid_rotation(chain: Chain2d, idx: int = 0, clckwise: bool = False):

# Mutate the rest of the chain by the chosen rotation, starting from idx
for i in range(idx, len(chain)):
chain[i][3] = (chain[i][3] + rot) % 4
chain[i][3] = (chain[i][3] + rot) % 4 # type: ignore


def crankshaft(chain: Chain2d, idx: int):
tmp1 = chain[idx][3]
tmp2 = chain[idx + 1][3]
tmp2 = chain[idx + 2][3]
if tmp1 != tmp2:
chain[idx][3] = tmp2
chain[idx + 1][3] = tmp1

chain[idx][3] = tmp2 # type: ignore
chain[idx + 2][3] = tmp1 # type: ignore

def three_bead_flip(chain, idx):
pass


def generate_next(state: Sequence, new_state: MutableSequence, randoms: Sequence[float], step) -> Any:
idx = int(math.floor((len(state) - 1.0001) * randoms[0]))

for i in range(len(state)):
new_state[i] = state[i]

if randoms[1] < 0.3 or idx > (len(state) - 3):
# if the vec index is on the end, do an end flip
clckwise = randoms[2] < 0.5
rigid_rotation(new_state, idx, clckwise=clckwise)
elif randoms[1] < 0.66:
# do a three-bead flip, i.e. switch two adjacent {n,e,w,s} directions
crankshaft(new_state, idx)
else:
three_bead_flip(new_state, idx)
def three_bead_flip(chain: Chain2d, idx: int):
tmp1 = chain[idx][3]
tmp2 = chain[idx + 1][3]
if tmp1 != tmp2:
chain[idx][3] = tmp2 # type: ignore
chain[idx + 1][3] = tmp1 # type: ignore


def generate_next(state: Sequence, new_state: Chain2d, randoms: Sequence[float], step) -> Any:
len_randoms = len(randoms)
n = 0
while n <= 100:
idx = int(math.floor((len(state) - 1.0001) * randoms[n % len_randoms]))
for i in range(len(state)):
new_state[i] = state[i]

if randoms[1] < 0.3 or idx > (len(state) - 3):
# if the vec index is on the end, do an end flip
clckwise = randoms[2] < 0.5
rigid_rotation(new_state, idx, clckwise=clckwise)
elif randoms[1] < 0.66:
# do a three-bead flip, i.e. switch two adjacent {n,e,w,s} directions
crankshaft(new_state, idx)
else:
three_bead_flip(new_state, idx)

if is_valid_conformation(new_state):
break

n += 1
88 changes: 88 additions & 0 deletions examples/hp/render.py
@@ -0,0 +1,88 @@
import plotly
import plotly.graph_objs as go
import networkx as nx
import math
import multiprocessing

from typing import List, Tuple

Chain = List[List[int]]


def render_plotly(chain: Chain, contacts: List[Tuple[int, int]], filename='') -> None:
hp_len = len(chain)
scale_factor = max(8, hp_len / 4)

contact_edges = []
for contact in contacts:
print('Processing contact')
coord1 = chain[contact[0]]
coord2 = chain[contact[1]]
contact_edges.append(
go.Scatter(
hoverinfo='none',
x=[coord1[1], coord2[1]],
y=[coord1[2], coord2[2]],
line=dict(width=8, color='rgb(183,183,183,0.3)', dash='dot'),
))

Xbe = [coord[1] for coord in chain]
Ybe = [coord[2] for coord in chain]
backbone_edges = go.Scatter(
hoverinfo='none',
x=Xbe,
y=Ybe,
line=dict(width=2 * scale_factor, color='black'),
)

Xn = [coord[1] for coord in chain]
Yn = [coord[2] for coord in chain]
# for idx, coord in enumerate(coords):
# print('coord[{}][1] + coord[{}][2] = {}', idx, idx, coord[idx][1] + coord[idx][2])
print(Xn)
print(Yn)
node_trace = go.Scatter(
hoverinfo='none',
x=Xn,
y=Yn,
text=['<b>{}</b>'.format(i) for i in range(len(chain))],
textposition='middle center',
textfont=dict(size=2 * scale_factor, color='rgb(160,160,160)'),
line={},
mode='markers+text',
marker=dict(
showscale=False,
# colorscale options
# 'Greys' | 'YlGnBu' | 'Greens' | 'YlOrRd' | 'Bluered' | 'RdBu' |
# 'Reds' | 'Blues' | 'Picnic' | 'Rainbow' | 'Portland' | 'Jet' |
# 'Hot' | 'Blackbody' | 'Earth' | 'Electric' | 'Viridis' |
# colorscale='Blackbody',
reversescale=True,
color=[],
line=dict(width=scale_factor, color='black'),
size=8 * scale_factor,
))

node_trace['marker']['color'] = list(map(lambda c: 'white' if c[0] == 1 else 'rgb(70,70,70)', chain))

fig = go.Figure(
data=[
backbone_edges,
*contact_edges,
node_trace,
],
layout=go.Layout(
autosize=False,
width=1000,
height=1000,
showlegend=False,
xaxis=dict(showgrid=True, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=True, zeroline=False, showticklabels=False, scaleanchor="x", scaleratio=1)))

try:
current_proc = multiprocessing.current_process()
if filename == '':
filename = 'hp_plot_' + str(current_proc.pid)
except Exception:
pass
plotly.offline.plot(fig, filename=filename + '.html')

0 comments on commit bf56ad8

Please sign in to comment.